Keras使用ImageNet上预训练的模型方式


Posted in Python onMay 23, 2020

我就废话不多说了,大家还是直接看代码吧!

import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
 
#Load the VGG model
vgg_model = vgg16.VGG16(weights='imagenet')
 
#Load the Inception_V3 model
inception_model = inception_v3.InceptionV3(weights='imagenet')
 
#Load the ResNet50 model
resnet_model = resnet50.ResNet50(weights='imagenet')
 
#Load the MobileNet model
mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代码中,我们首先import各种模型对应的module,然后load模型,并用ImageNet的参数初始化模型的参数。

如果不想使用ImageNet上预训练到的权重初始话模型,可以将各语句的中'imagenet'替换为'None'。

补充知识:keras上使用alexnet模型来高准确度对mnist数据进行分类

纲要

本文有两个特点:一是直接对本地mnist数据进行读取(假设事先已经下载或从别处拷来)二是基于keras框架(网上多是基于tf)使用alexnet对mnist数据进行分类,并获得较高准确度(约为98%)

本地数据读取和分析

很多代码都是一开始简单调用一行代码来从网站上下载mnist数据,虽然只有10来MB,但是现在下载速度非常慢,而且经常中途出错,要费很大的劲才能拿到数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其实可以单独来获得这些数据(一共4个gz包,如下所示),然后调用别的接口来分析它们。

Keras使用ImageNet上预训练的模型方式

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

这里面要注意的是,两种接口拿到的数据形式是不一样的。 从网上直接下载下来的数据 其image data值的范围是0~255,且label值为0,1,2,3...9。 而第二种接口获取的数据 image值已经除以255(归一化)变成0~1范围,且label值已经是one-hot形式(one_hot=True时),比如label值2的one-hot code为(0 0 1 0 0 0 0 0 0 0)

所以,以第一种方式获取的数据需要做一些预处理(归一和one-hot)才能输入网络模型进行训练 而第二种接口拿到的数据则可以直接进行训练。

Alexnet模型的微调

按照公开的模型框架,Alexnet只有第1、2个卷积层才跟着BatchNormalization,后面三个CNN都没有(如有说错,请指正)。如果按照这个来搭建网络模型,很容易导致梯度消失,现象就是 accuracy值一直处在很低的值。 如下所示。

Keras使用ImageNet上预训练的模型方式

在每个卷积层后面都加上BN后,准确度才迭代提高。如下所示

Keras使用ImageNet上预训练的模型方式

完整代码

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #tensorflow已经包含了mnist案例的数据
 
batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28,28,1)
 
# input dimensions
img_rows, img_cols = 28,28
 
# dataset input
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
 
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
 
# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
 
# Define the input layer
inputs = keras.Input(shape = [img_rows, img_cols, 1])
 
 #Define the converlutional layer 1
conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)
# Define the pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
# Define the standardization layer 1
stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)
 
# Define the converlutional layer 2
conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
# Defien the pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
# Define the standardization layer 2
stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)
 
# Define the converlutional layer 3
conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
 
# Define the converlutional layer 4
conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
 
# Define the converlutional layer 5
conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)
 
# Define the fully connected layer
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
 
fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
 
fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
 
# 基于Model方法构建模型
model = keras.Model(inputs= inputs, outputs = fc3)
# 编译模型
model.compile(optimizer= tf.train.AdamOptimizer(0.001),
       loss= keras.losses.categorical_crossentropy,
       metrics= ['accuracy'])
# 训练配置,仅供参考
model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上这篇Keras使用ImageNet上预训练的模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python os.path模块常用方法实例详解
Sep 16 Python
python组合无重复三位数的实例
Nov 13 Python
python3实现网络爬虫之BeautifulSoup使用详解
Dec 19 Python
python实现名片管理系统项目
Apr 26 Python
基于Python安装pyecharts所遇的问题及解决方法
Aug 12 Python
详解PyTorch手写数字识别(MNIST数据集)
Aug 16 Python
Python大数据之网络爬虫的post请求、get请求区别实例分析
Nov 16 Python
python中数据库like模糊查询方式
Mar 02 Python
python中selenium库的基本使用详解
Jul 31 Python
python处理写入数据代码讲解
Oct 22 Python
Python  Asyncio模块实现的生产消费者模型的方法
Mar 01 Python
python实现腾讯滑块验证码识别
Apr 27 Python
使用Keras预训练模型ResNet50进行图像分类方式
May 23 #Python
基于Python中random.sample()的替代方案
May 23 #Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
May 23 #Python
keras中模型训练class_weight,sample_weight区别说明
May 23 #Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
Keras实现将两个模型连接到一起
May 23 #Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
You might like
暴雪前总裁遗憾:没尽早追赶Dota 取消星际争霸幽灵
2020/03/08 星际争霸
Win2003下APACHE+PHP5+MYSQL4+PHPMYADMIN 的简易安装配置
2006/11/18 PHP
mysql 查询指定日期时间内sql语句实现原理与代码
2012/12/16 PHP
php下pdo的mysql事务处理用法实例
2014/12/27 PHP
PHP Swoole异步读取、写入文件操作示例
2019/10/24 PHP
PHP 数组操作详解【遍历、指针、函数等】
2020/05/13 PHP
jQuery ajax cache缓存问题
2010/07/01 Javascript
JS控制网页动态生成任意行列数表格的方法
2015/03/09 Javascript
artDialog+plupload实现多文件上传
2016/07/19 Javascript
JS验证 只能输入小数点,数字,负数的实现方法
2016/10/07 Javascript
Nodejs进阶:如何将图片转成datauri嵌入到网页中去实例
2016/11/21 NodeJs
jQuery Easyui加载表格出错时在表格中间显示自定义的提示内容
2016/12/08 Javascript
JsChart组件使用详解
2018/03/04 Javascript
深入理解JS的事件绑定、事件流模型
2018/05/13 Javascript
vue2.0+vuex+localStorage代办事项应用实现详解
2018/05/31 Javascript
vue cli2.0单页面title修改方法
2018/06/07 Javascript
javascript使用正则实现去掉字符串前面的所有0
2018/07/23 Javascript
JS实现把一个页面层数据传递到另一个页面的两种方式
2018/08/13 Javascript
在vue中多次调用同一个定义全局变量的实例
2018/09/25 Javascript
微信小程序整个页面的自动适应布局的实现
2020/07/12 Javascript
django+mysql的使用示例
2018/11/23 Python
利用Python模拟登录pastebin.com的实现方法
2019/07/12 Python
opencv实现简单人脸识别
2021/02/19 Python
Python3将jpg转为pdf文件的方法示例
2019/12/13 Python
Jupyter安装链接aconda实现过程图解
2020/11/02 Python
高级护理实习生自荐信
2013/09/28 职场文书
最新的大学生找工作自我评价
2013/09/29 职场文书
秋天的雨教学反思
2014/04/27 职场文书
责任担保书范文
2014/05/21 职场文书
法学自荐信
2014/06/20 职场文书
2014年世界艾滋病日宣传活动总结
2014/11/18 职场文书
毕业欢送会致辞
2015/07/29 职场文书
旅行社计调工作总结
2015/08/12 职场文书
2016年六一儿童节开幕词
2016/03/04 职场文书
一条 SQL 语句执行过程
2022/03/17 MySQL
gtx1650怎么样 gtx1650显卡相当于什么级别
2022/04/08 数码科技