使用Keras构造简单的CNN网络实例


Posted in Python onJune 29, 2020

1. 导入各种模块

基本形式为:

import 模块名

from 某个文件 import 某个模块

2. 导入数据(以两类分类问题为例,即numClass = 2)

训练集数据data

可以看到,data是一个四维的ndarray

训练集的标签

3. 将导入的数据转化我keras可以接受的数据格式

keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数

label = np_utils.to_categorical(label, numClass

此时的label变为了如下形式

(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)

4. 建立CNN模型

以下图所示的CNN网络为例

#生成一个model
model = Sequential()
 
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
 
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
 
# layer5-fully connect
model.add(Dense(numClass, init='normal')) 
model.add(Activation('softmax'))

# 
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")

5. 开始训练model

利用model.train_on_batch或者model.fit

补充知识:keras 多分类一些函数参数设置

用Lenet-5 识别Mnist数据集为例子:

采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。

首先import所需要的模块

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K

定义图像数据信息及训练参数

img_width, img_height = 28, 28 
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000 
nb_validation_samples = 10000
epochs = 50 
batch_size = 32

判断使用的后台

if K.image_dim_ordering() == 'th':
 input_shape = (3, img_width, img_height)
else:
 input_shape = (img_width, img_height, 3)

网络模型定义

主要注意最后的输出层定义

比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:

x = Dense(10,activation='softmax')(x)

此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid

img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)


model.compile(loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])
model.summary()

利用ImageDataGenerator传入图像数据集

注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。

train_datagen = ImageDataGenerator(
 rescale=1. / 255,
 shear_range=0.2,
 zoom_range=0.2,
 horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
 train_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical', #多分类问题设为'categorical'
 classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
 )

validation_generator = test_datagen.flow_from_directory(
 validation_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical'
 )

训练和保存模型及权值

model.fit_generator(
  train_generator,
  samples_per_epoch=nb_train_samples,
  nb_epoch=epochs,
  validation_data=validation_generator,
  nb_val_samples=nb_validation_samples
  )

model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')

至此训练结束

图片预测

注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。

此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。

from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
 img_addr=input('Please input your image address:')
 if img_addr=="exit":
  break
 else:
  img = load_img(img_addr, False, target_size=(28, 28))
  x = img_to_array(img) / 255.0
  x = np.expand_dims(x, axis=0)
  result = model.predict(x)
  ind=np.argmax(result,1)
  print('this is a ', classes[ind])

以上这篇使用Keras构造简单的CNN网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中内建函数的简单用法说明
May 05 Python
python3实现UDP协议的服务器和客户端
Jun 14 Python
python用户管理系统的实例讲解
Dec 23 Python
python实现定时自动备份文件到其他主机的实例代码
Feb 23 Python
在python2.7中用numpy.reshape 对图像进行切割的方法
Dec 05 Python
Python从单元素字典中获取key和value的实例
Dec 31 Python
Python3 文章标题关键字提取的例子
Aug 26 Python
Django自定义用户表+自定义admin后台中的字段实例
Nov 18 Python
Python操作Excel把数据分给sheet
May 20 Python
Python3中PyQt5简单实现文件打开及保存
Jun 10 Python
详解Python+OpenCV绘制灰度直方图
Mar 22 Python
python中pd.cut()与pd.qcut()的对比及示例
Jun 16 Python
基于K.image_data_format() == 'channels_first' 的理解
Jun 29 #Python
Python enumerate() 函数如何实现索引功能
Jun 29 #Python
解决Keras中CNN输入维度报错问题
Jun 29 #Python
Python字符串split及rsplit方法原理详解
Jun 29 #Python
浅谈Keras参数 input_shape、input_dim和input_length用法
Jun 29 #Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 #Python
使用keras时input_shape的维度表示问题说明
Jun 29 #Python
You might like
php函数的常用方法及注意之处小结
2011/07/10 PHP
编写Smarty插件在模板中直接加载数据的详细介绍
2013/06/26 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(五)
2014/06/23 PHP
php通过session防url攻击方法
2014/12/10 PHP
深入分析PHP优化及注意事项
2016/07/04 PHP
PHP学习笔记之session
2018/05/06 PHP
Javascript创建自定义对象 创建Object实例添加属性和方法
2012/06/04 Javascript
Js base64 加密解密介绍
2013/10/11 Javascript
单击和双击事件的冲突处理示例代码
2014/04/03 Javascript
jQuery中trigger()方法用法实例
2015/01/19 Javascript
jquery 设置style:display的方法
2015/01/29 Javascript
jqGrid表格应用之新增与删除数据附源码下载
2015/12/02 Javascript
AngularJS控制器之间的通信方式详解
2016/11/03 Javascript
js遍历json对象所有key及根据动态key获取值的方法(必看)
2017/03/09 Javascript
VueJs组件prop验证简单介绍
2017/09/12 Javascript
微信小程序之GET请求的实例详解
2017/09/29 Javascript
讲解vue-router之命名路由和命名视图
2018/05/28 Javascript
layui 表格的属性的显示转换方法
2018/08/14 Javascript
微信小程序首页的分类功能和搜索功能的实现思路及代码详解
2018/09/11 Javascript
从零到一详聊创建Vue工程及遇到的常见问题
2019/04/25 Javascript
vue中img src 动态加载本地json的图片路径写法
2019/04/25 Javascript
关于JavaScript数组去重的一些理解汇总
2020/09/10 Javascript
Python制作爬虫抓取美女图
2016/01/20 Python
详解python使用递归、尾递归、循环三种方式实现斐波那契数列
2018/01/16 Python
python 自定义异常和异常捕捉的方法
2018/10/18 Python
Python 进程之间共享数据(全局变量)的方法
2019/07/16 Python
Python3+Selenium+Chrome实现自动填写WPS表单
2020/02/12 Python
通过代码实例了解Python异常本质
2020/09/16 Python
Python如何批量生成和调用变量
2020/11/21 Python
Proenza Schouler官方网站:纽约女装和配饰品牌
2019/01/03 全球购物
声明struct x1 { . . . }; 和typedef struct { . . . }x2;有什么不同
2012/06/02 面试题
体育老师的教学自我评价分享
2013/11/19 职场文书
违纪检讨书2000字
2014/02/08 职场文书
毕业生求职信
2014/06/10 职场文书
2015年发展党员工作总结报告
2015/03/31 职场文书
校运会通讯稿
2015/07/18 职场文书