Keras使用ImageNet上预训练的模型方式


Posted in Python onMay 23, 2020

我就废话不多说了,大家还是直接看代码吧!

import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
 
#Load the VGG model
vgg_model = vgg16.VGG16(weights='imagenet')
 
#Load the Inception_V3 model
inception_model = inception_v3.InceptionV3(weights='imagenet')
 
#Load the ResNet50 model
resnet_model = resnet50.ResNet50(weights='imagenet')
 
#Load the MobileNet model
mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代码中,我们首先import各种模型对应的module,然后load模型,并用ImageNet的参数初始化模型的参数。

如果不想使用ImageNet上预训练到的权重初始话模型,可以将各语句的中'imagenet'替换为'None'。

补充知识:keras上使用alexnet模型来高准确度对mnist数据进行分类

纲要

本文有两个特点:一是直接对本地mnist数据进行读取(假设事先已经下载或从别处拷来)二是基于keras框架(网上多是基于tf)使用alexnet对mnist数据进行分类,并获得较高准确度(约为98%)

本地数据读取和分析

很多代码都是一开始简单调用一行代码来从网站上下载mnist数据,虽然只有10来MB,但是现在下载速度非常慢,而且经常中途出错,要费很大的劲才能拿到数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其实可以单独来获得这些数据(一共4个gz包,如下所示),然后调用别的接口来分析它们。

Keras使用ImageNet上预训练的模型方式

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

这里面要注意的是,两种接口拿到的数据形式是不一样的。 从网上直接下载下来的数据 其image data值的范围是0~255,且label值为0,1,2,3...9。 而第二种接口获取的数据 image值已经除以255(归一化)变成0~1范围,且label值已经是one-hot形式(one_hot=True时),比如label值2的one-hot code为(0 0 1 0 0 0 0 0 0 0)

所以,以第一种方式获取的数据需要做一些预处理(归一和one-hot)才能输入网络模型进行训练 而第二种接口拿到的数据则可以直接进行训练。

Alexnet模型的微调

按照公开的模型框架,Alexnet只有第1、2个卷积层才跟着BatchNormalization,后面三个CNN都没有(如有说错,请指正)。如果按照这个来搭建网络模型,很容易导致梯度消失,现象就是 accuracy值一直处在很低的值。 如下所示。

Keras使用ImageNet上预训练的模型方式

在每个卷积层后面都加上BN后,准确度才迭代提高。如下所示

Keras使用ImageNet上预训练的模型方式

完整代码

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #tensorflow已经包含了mnist案例的数据
 
batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28,28,1)
 
# input dimensions
img_rows, img_cols = 28,28
 
# dataset input
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
 
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
 
# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
 
# Define the input layer
inputs = keras.Input(shape = [img_rows, img_cols, 1])
 
 #Define the converlutional layer 1
conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)
# Define the pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
# Define the standardization layer 1
stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)
 
# Define the converlutional layer 2
conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
# Defien the pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
# Define the standardization layer 2
stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)
 
# Define the converlutional layer 3
conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
 
# Define the converlutional layer 4
conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
 
# Define the converlutional layer 5
conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)
 
# Define the fully connected layer
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
 
fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
 
fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
 
# 基于Model方法构建模型
model = keras.Model(inputs= inputs, outputs = fc3)
# 编译模型
model.compile(optimizer= tf.train.AdamOptimizer(0.001),
       loss= keras.losses.categorical_crossentropy,
       metrics= ['accuracy'])
# 训练配置,仅供参考
model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上这篇Keras使用ImageNet上预训练的模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 2.7.x 和 3.x 版本的重要区别小结
Nov 28 Python
Python编写生成验证码的脚本的教程
May 04 Python
Python处理JSON数据并生成条形图
Aug 05 Python
python数据结构之链表的实例讲解
Jul 25 Python
Django REST为文件属性输出完整URL的方法
Dec 18 Python
Python数学形态学实例分析
Sep 06 Python
Pandas聚合运算和分组运算的实现示例
Oct 17 Python
Python实现打印实心和空心菱形
Nov 23 Python
python为Django项目上的每个应用程序创建不同的自定义404页面(最佳答案)
Mar 09 Python
python中def是做什么的
Jun 10 Python
python 日志模块logging的使用场景及示例
Jan 04 Python
Python如何解决secure_filename对中文不支持问题
Jul 16 Python
使用Keras预训练模型ResNet50进行图像分类方式
May 23 #Python
基于Python中random.sample()的替代方案
May 23 #Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
May 23 #Python
keras中模型训练class_weight,sample_weight区别说明
May 23 #Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
Keras实现将两个模型连接到一起
May 23 #Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
You might like
php实现数组中索引关联数据转换成json对象的方法
2015/07/08 PHP
从ThinkPHP3.2.3过渡到ThinkPHP5.0学习笔记图文详解
2019/04/03 PHP
js实现幻灯片播放图片示例代码
2013/11/07 Javascript
html5 canvas js(数字时钟)实例代码
2013/12/23 Javascript
jQuery中val()方法用法实例
2014/12/25 Javascript
JS交换变量的方法
2015/01/21 Javascript
jQuery实现购物车数字加减效果
2015/03/14 Javascript
详解微信小程序设置底部导航栏目方法
2017/06/29 Javascript
JS实现微信摇一摇原理解析
2017/07/22 Javascript
JavaScript中Object基础内部方法图
2018/02/05 Javascript
jQuery实现碰到边缘反弹的动画效果
2018/02/24 jQuery
微信小程序实现自动定位功能
2018/10/31 Javascript
一些手写JavaScript常用的函数汇总
2019/04/16 Javascript
vue2之简易的pc端短信验证码的问题及处理方法
2019/06/03 Javascript
LayUi使用switch开关,动态的去控制它是否被启用的方法
2019/09/21 Javascript
JS插件amCharts实现绘制柱形图默认显示数值功能示例
2019/11/26 Javascript
vue中的v-model原理,与组件自定义v-model详解
2020/08/04 Javascript
原生JS生成指定位数的验证码
2020/10/28 Javascript
Javascript新手入门之字符串拼接与变量的应用
2020/12/03 Javascript
Python程序中设置HTTP代理
2016/11/06 Python
python脚本生成caffe train_list.txt的方法
2018/04/27 Python
Tensorflow卷积神经网络实例进阶
2018/05/24 Python
Python 网络爬虫--关于简单的模拟登录实例讲解
2018/06/01 Python
关于python pycharm中输出的内容不全的解决办法
2020/01/10 Python
python框架Django实战商城项目之工程搭建过程图文详解
2020/03/09 Python
Hertz荷兰:荷兰和全球租车
2018/01/07 全球购物
Cult Gaia官网:美国生活方式品牌
2019/08/16 全球购物
澳大利亚最受欢迎的女士度假服装:Kabana Shop
2020/10/10 全球购物
索引覆盖(Index Covering)查询含义
2012/02/18 面试题
什么是抽象
2015/12/13 面试题
公司综合部的成员自我评价分享
2013/11/05 职场文书
测绘工程个人的自我评价
2013/11/10 职场文书
互联网电子商务专业毕业生求职信
2014/03/18 职场文书
python和anaconda的区别
2022/05/06 Python
SQL bool盲注和时间盲注详解
2022/07/23 SQL Server
CSS使用Flex和Grid布局实现3D骰子
2022/08/05 HTML / CSS