使用Keras构造简单的CNN网络实例


Posted in Python onJune 29, 2020

1. 导入各种模块

基本形式为:

import 模块名

from 某个文件 import 某个模块

2. 导入数据(以两类分类问题为例,即numClass = 2)

训练集数据data

可以看到,data是一个四维的ndarray

训练集的标签

3. 将导入的数据转化我keras可以接受的数据格式

keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数

label = np_utils.to_categorical(label, numClass

此时的label变为了如下形式

(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)

4. 建立CNN模型

以下图所示的CNN网络为例

#生成一个model
model = Sequential()
 
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
 
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
 
# layer5-fully connect
model.add(Dense(numClass, init='normal')) 
model.add(Activation('softmax'))

# 
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")

5. 开始训练model

利用model.train_on_batch或者model.fit

补充知识:keras 多分类一些函数参数设置

用Lenet-5 识别Mnist数据集为例子:

采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。

首先import所需要的模块

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K

定义图像数据信息及训练参数

img_width, img_height = 28, 28 
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000 
nb_validation_samples = 10000
epochs = 50 
batch_size = 32

判断使用的后台

if K.image_dim_ordering() == 'th':
 input_shape = (3, img_width, img_height)
else:
 input_shape = (img_width, img_height, 3)

网络模型定义

主要注意最后的输出层定义

比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:

x = Dense(10,activation='softmax')(x)

此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid

img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)


model.compile(loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])
model.summary()

利用ImageDataGenerator传入图像数据集

注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。

train_datagen = ImageDataGenerator(
 rescale=1. / 255,
 shear_range=0.2,
 zoom_range=0.2,
 horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
 train_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical', #多分类问题设为'categorical'
 classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
 )

validation_generator = test_datagen.flow_from_directory(
 validation_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical'
 )

训练和保存模型及权值

model.fit_generator(
  train_generator,
  samples_per_epoch=nb_train_samples,
  nb_epoch=epochs,
  validation_data=validation_generator,
  nb_val_samples=nb_validation_samples
  )

model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')

至此训练结束

图片预测

注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。

此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。

from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
 img_addr=input('Please input your image address:')
 if img_addr=="exit":
  break
 else:
  img = load_img(img_addr, False, target_size=(28, 28))
  x = img_to_array(img) / 255.0
  x = np.expand_dims(x, axis=0)
  result = model.predict(x)
  ind=np.argmax(result,1)
  print('this is a ', classes[ind])

以上这篇使用Keras构造简单的CNN网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现从网络下载文件并获得文件大小及类型的方法
Apr 28 Python
Python利用IPython提高开发效率
Aug 10 Python
python安装oracle扩展及数据库连接方法
Feb 21 Python
pandas series序列转化为星期几的实例
Apr 11 Python
Selenium鼠标与键盘事件常用操作方法示例
Aug 13 Python
python开发之anaconda以及win7下安装gensim的方法
Jul 05 Python
Python paramiko 模块浅谈与SSH主要功能模拟解析
Feb 29 Python
Python中求对数方法总结
Mar 10 Python
Python使用Chrome插件实现爬虫过程图解
Jun 09 Python
深入了解Python 变量作用域
Jul 24 Python
python中doctest库实例用法
Dec 31 Python
python绘图模块之利用turtle画图
Feb 12 Python
基于K.image_data_format() == 'channels_first' 的理解
Jun 29 #Python
Python enumerate() 函数如何实现索引功能
Jun 29 #Python
解决Keras中CNN输入维度报错问题
Jun 29 #Python
Python字符串split及rsplit方法原理详解
Jun 29 #Python
浅谈Keras参数 input_shape、input_dim和input_length用法
Jun 29 #Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 #Python
使用keras时input_shape的维度表示问题说明
Jun 29 #Python
You might like
PHP中array_keys和array_unique函数源码的分析
2016/02/26 PHP
详谈phpAdmin修改密码后拒绝访问的问题
2017/04/03 PHP
Linux下 php7安装redis的方法
2018/11/01 PHP
不常用但很实用的PHP预定义变量分析
2019/06/25 PHP
Laravel框架实现的上传图片到七牛功能详解
2019/09/06 PHP
jquery ajax return没有返回值的解决方法
2011/10/20 Javascript
用js通过url传参把数据从一个页面传到另一个页面
2014/09/01 Javascript
原生javascript实现图片弹窗交互效果
2015/01/12 Javascript
JS实现同时搜索百度和必应的方法
2015/01/27 Javascript
jquery实现用户信息修改验证输入方法汇总
2015/07/18 Javascript
JS实现Fisheye效果动感放大菜单代码
2015/10/21 Javascript
Jquery Mobile 自定义按钮图标
2015/11/18 Javascript
Javascript闭包实例详解
2015/11/29 Javascript
Eclipse编辑jsp、js文件时卡死现象的解决办法汇总
2016/02/02 Javascript
Javascript 基础---Ajax入门必看
2016/07/06 Javascript
jQuery实现的自定义滚动条实例详解
2016/09/20 Javascript
vue实现可增删查改的成绩单
2016/10/27 Javascript
EditPlus中的正则表达式 实战(4)
2016/12/15 Javascript
Vue.js仿Metronic高级表格(一)静态设计
2017/04/17 Javascript
js中变量的连续赋值(实例讲解)
2017/07/08 Javascript
解决Vue使用mint-ui loadmore实现上拉加载与下拉刷新出现一个页面使用多个上拉加载后冲突问题
2017/11/07 Javascript
Vue的实例、生命周期与Vue脚手架(vue-cli)实例详解
2017/12/27 Javascript
详解关于微信setData回调函数中的坑
2019/02/18 Javascript
关于layui时间回显问题的解决方法
2019/09/24 Javascript
vue-admin-template配置快捷导航的代码(标签导航栏)
2020/09/04 Javascript
python读取html中指定元素生成excle文件示例
2014/04/03 Python
使用Numpy读取CSV文件,并进行行列删除的操作方法
2018/07/04 Python
Anaconda2 5.2.0安装使用图文教程
2018/09/19 Python
Django自关联实现多级联动查询实例
2020/05/19 Python
使用layui框架实现点击左侧导航切换右侧内容且右侧选项卡跟随变化的效果
2020/11/10 HTML / CSS
双立人美国官方商店:ZWILLING集团餐具和炊具
2020/05/07 全球购物
幼儿园教师读书笔记
2015/06/29 职场文书
2015中学教师个人工作总结
2015/07/22 职场文书
学习十八大的感悟
2015/08/11 职场文书
如何做好员工培训计划?
2019/07/09 职场文书
Flask response响应的具体使用
2021/07/15 Python