keras训练浅层卷积网络并保存和加载模型实例


Posted in Python onJuly 02, 2020

这里我们使用keras定义简单的神经网络全连接层训练MNIST数据集和cifar10数据集:

keras_mnist.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import argparse
# 命令行参数运行
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args =vars(ap.parse_args())
# 加载数据MNIST,然后归一化到【0,1】,同时使用75%做训练,25%做测试
print("[INFO] loading MNIST (full) dataset")
dataset = datasets.fetch_mldata("MNIST Original", data_home="/home/king/test/python/train/pyimagesearch/nn/data/")
data = dataset.data.astype("float") / 255.0
(trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25)
# 将label进行one-hot编码
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# keras定义网络结构784--256--128--10
model = Sequential()
model.add(Dense(256, input_shape=(784,), activation="relu"))
model.add(Dense(128, activation="relu"))
model.add(Dense(10, activation="softmax"))
# 开始训练
print("[INFO] training network...")
# 0.01的学习率
sgd = SGD(0.01)
# 交叉验证
model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=['accuracy'])
H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=128)
# 测试模型和评估
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=128)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=[str(x) for x in lb.classes_]))
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

使用relu做激活函数:

keras训练浅层卷积网络并保存和加载模型实例

使用sigmoid做激活函数:

keras训练浅层卷积网络并保存和加载模型实例

接着我们自己定义一些modules去实现一个简单的卷基层去训练cifar10数据集:

imagetoarraypreprocessor.py

'''
该函数主要是实现keras的一个细节转换,因为训练的图像时RGB三颜色通道,读取进来的数据是有depth的,keras为了兼容一些后台,默认是按照(height, width, depth)读取,但有时候就要改变成(depth, height, width)
'''
from keras.preprocessing.image import img_to_array
class ImageToArrayPreprocessor:
	def __init__(self, dataFormat=None):
		self.dataFormat = dataFormat
 
	def preprocess(self, image):
		return img_to_array(image, data_format=self.dataFormat)

shallownet.py

'''
定义一个简单的卷基层:
input->conv->Relu->FC
'''
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation, Flatten, Dense
from keras import backend as K
 
class ShallowNet:
	@staticmethod
	def build(width, height, depth, classes):
		model = Sequential()
		inputShape = (height, width, depth)
 
		if K.image_data_format() == "channels_first":
			inputShape = (depth, height, width)
 
		model.add(Conv2D(32, (3, 3), padding="same", input_shape=inputShape))
		model.add(Activation("relu"))
 
		model.add(Flatten())
		model.add(Dense(classes))
		model.add(Activation("softmax"))
 
		return model

然后就是训练代码:

keras_cifar10.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.0001)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=1000, verbose=1)
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 1000), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 1000), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 1000), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 1000), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

代码中可以对训练的learning rate进行微调,大概可以接近60%的准确率。

keras训练浅层卷积网络并保存和加载模型实例

keras训练浅层卷积网络并保存和加载模型实例

然后修改下代码可以保存训练模型:

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True, help="path to the output loss/accuracy plot")
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] compiling model...")
opt = SGD(lr=0.005)
model = ShallowNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
 
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY), batch_size=32, epochs=50, verbose=1)
 
model.save(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1), predictions.argmax(axis=1), 
	target_names=labelNames))
 
# 保存可视化训练结果
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 5), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 5), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 5), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 5), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("# Epoch")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args["output"])

命令行运行:

keras训练浅层卷积网络并保存和加载模型实例

我们使用另一个程序来加载上一次训练保存的模型,然后进行测试:

test.py

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from shallownet import ShallowNet
from keras.optimizers import SGD
from keras.datasets import cifar10
from keras.models import load_model
import matplotlib.pyplot as plt
import numpy as np
import argparse
 
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True, help="path to save train model")
args = vars(ap.parse_args())
 
# 标签0-9代表的类别string
labelNames = ['airplane', 'automobile', 'bird', 'cat', 
	'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
 
print("[INFO] loading CIFAR-10 dataset")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
 
idxs = np.random.randint(0, len(testX), size=(10,))
testX = testX[idxs]
testY = testY[idxs]
 
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
 
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
 
print("[INFO] loading pre-trained network...")
model = load_model(args["model"])
 
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32).argmax(axis=1)
print("predictions\n", predictions)
for i in range(len(testY)):
	print("label:{}".format(labelNames[predictions[i]]))
 
trueLabel = []
for i in range(len(testY)):
	for j in range(len(testY[i])):
		if testY[i][j] != 0:
			trueLabel.append(j)
print(trueLabel)
 
print("ground truth testY:")
for i in range(len(trueLabel)):
	print("label:{}".format(labelNames[trueLabel[i]]))
 
print("TestY\n", testY)

keras训练浅层卷积网络并保存和加载模型实例

以上这篇keras训练浅层卷积网络并保存和加载模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现对比不同字体中的同一字符的显示效果
Apr 23 Python
用Python程序抓取网页的HTML信息的一个小实例
May 02 Python
Python Sql数据库增删改查操作简单封装
Apr 18 Python
python实现决策树C4.5算法详解(在ID3基础上改进)
May 31 Python
Python 实现删除某路径下文件及文件夹的实例讲解
Apr 24 Python
Flask框架信号用法实例分析
Jul 24 Python
python绘制漏斗图步骤详解
Mar 04 Python
python操作日志的封装方法(两种方法)
May 23 Python
Python实现序列化及csv文件读取
Jan 19 Python
Django import export实现数据库导入导出方式
Apr 03 Python
Pycharm2020.1安装无法启动问题即设置中文插件的方法
Aug 07 Python
一小时学会TensorFlow2之基本操作2实例代码
Sep 04 Python
Python RabbitMQ实现简单的进程间通信示例
Jul 02 #Python
利用scikitlearn画ROC曲线实例
Jul 02 #Python
Python使用文件操作实现一个XX信息管理系统的示例
Jul 02 #Python
keras用auc做metrics以及早停实例
Jul 02 #Python
keras 简单 lstm实例(基于one-hot编码)
Jul 02 #Python
Python装饰器结合递归原理解析
Jul 02 #Python
Python OpenCV读取中文路径图像的方法
Jul 02 #Python
You might like
discuz 首页四格:最新话题+最新回复+热门话题+精华文章插件
2007/08/19 PHP
PHP JSON 数据解析代码
2010/05/26 PHP
php中$美元符号与Zen Coding冲突问题解决方法分享
2014/05/28 PHP
在Mac OS的PHP环境下安装配置MemCache的全过程解析
2016/02/15 PHP
php实现QQ小程序发送模板消息功能
2019/09/18 PHP
laravel-admin 在列表页添加自定义按钮的例子
2019/09/30 PHP
jQuery页面滚动浮动层智能定位实例代码
2011/08/23 Javascript
Thinkphp模板没有解析直接原样输出的解决方法
2014/10/31 Javascript
浅谈Javascript如何实现匀速运动
2014/12/19 Javascript
利用js实现禁止复制文本信息
2015/06/03 Javascript
JavaScript事件类型中UI事件详解
2016/01/14 Javascript
通过Tabs方法基于easyUI+bootstrap制作工作站
2016/03/28 Javascript
一些实用性较高的js方法
2016/04/19 Javascript
浅谈js函数中的实例对象、类对象、局部变量(局部函数)
2016/11/20 Javascript
jquery 多个radio的click事件实例
2016/12/03 Javascript
Angular2 http jsonp的实例详解
2017/08/31 Javascript
vue获取dom元素注意事项
2017/12/28 Javascript
微信小程序学习笔记之函数定义、页面渲染图文详解
2019/03/28 Javascript
微信小程序系列之自定义顶部导航功能
2019/05/21 Javascript
jQuery Migrate 插件用法实例详解
2019/05/22 jQuery
layui在form表单页面通过Validform加入简单验证的方法
2019/09/06 Javascript
基于Electron实现桌面应用开发代码实例
2020/07/07 Javascript
js实现鼠标滑动到某个div禁止滚动
2020/09/17 Javascript
django模型中的字段和model名显示为中文小技巧分享
2014/11/18 Python
通过数据库向Django模型添加字段的示例
2015/07/21 Python
pip install urllib2不能安装的解决方法
2018/06/12 Python
对Python中列表和数组的赋值,浅拷贝和深拷贝的实例讲解
2018/06/28 Python
django多对多表的创建,级联删除及手动创建第三张表
2019/07/25 Python
Python实现快速排序的方法详解
2019/10/25 Python
HTML5事件方法全部汇总
2016/05/12 HTML / CSS
美国婴童服装市场上的领先品牌:Carter’s
2018/02/08 全球购物
大课间活动制度
2014/01/18 职场文书
优秀导游先进事迹材料
2014/01/25 职场文书
2014两会学习心得:榜样精神伴我行
2014/03/17 职场文书
政工师工作总结2015
2015/05/26 职场文书
导游词之绍兴柯岩古镇
2020/01/09 职场文书