Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python开发常用的一些开源Package分享
Feb 14 Python
python使用三角迭代计算圆周率PI的方法
Mar 20 Python
简单介绍Python中的RSS处理
Apr 13 Python
python实现批量改文件名称的方法
May 25 Python
详解python中requirements.txt的一切
Mar 03 Python
django实现用户登陆功能详解
Dec 11 Python
python opencv之分水岭算法示例
Feb 24 Python
python如何读写json数据
Mar 21 Python
12个步骤教你理解Python装饰器
Jul 01 Python
Python with用法:自动关闭文件进程
Jul 10 Python
Python CSV文件模块的使用案例分析
Dec 21 Python
自定义实现 PyQt5 下拉复选框 ComboCheckBox的完整代码
Mar 30 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
如何删除多级目录
2006/10/09 PHP
php引用地址改变变量值的问题
2012/03/23 PHP
详解ThinkPHP3.2.3验证码显示、刷新、校验
2016/12/29 PHP
PHP创建自己的Composer包方法
2018/04/09 PHP
filemanage功能中用到的lib.js
2007/04/08 Javascript
jQuery弹出层插件简化版代码下载
2008/10/16 Javascript
关于IE、Firefox、Opera页面呈现异同 写脚本很痛苦
2009/08/28 Javascript
50款非常棒的 jQuery 插件分享
2012/03/29 Javascript
js实现飞入星星特效代码
2014/10/17 Javascript
javascript面向对象之定义成员方法实例分析
2015/01/13 Javascript
JS获取数组最大值、最小值及长度的方法
2015/11/24 Javascript
AngularJS教程之简单应用程序示例
2016/08/16 Javascript
javascript数组常用方法汇总
2016/09/10 Javascript
IOS中safari下的select下拉菜单文字过长不换行的解决方法
2016/09/26 Javascript
JS中实现函数return多个返回值的实例
2017/02/21 Javascript
js数字计算 误差问题的快速解决方法
2017/02/28 Javascript
通过学习bootstrop导航条学会修改bootstrop颜色基调
2017/06/11 Javascript
使用Vue-Router 2实现路由功能实例详解
2017/11/14 Javascript
JavaScript交换变量常用4种方法解析
2020/09/02 Javascript
python奇偶行分开存储实现代码
2018/03/19 Python
利用scrapy将爬到的数据保存到mysql(防止重复)
2018/03/31 Python
python提取图像的名字*.jpg到txt文本的方法
2018/05/10 Python
python交互界面的退出方法
2019/02/16 Python
Python编写带选项的命令行程序方法
2019/08/13 Python
Python Opencv提取图片中某种颜色组成的图形的方法
2019/09/19 Python
python-图片流传输的思路及示例(url转换二维码)
2020/12/21 Python
英国内衣连锁店:Boux Avenue
2018/01/24 全球购物
德国黑胶唱片、街头服装及运动鞋网上商店:HHV
2018/08/24 全球购物
应届生体育教师自荐信
2013/10/03 职场文书
党员评议思想汇报
2014/10/08 职场文书
督导岗位职责范本
2015/04/10 职场文书
党员证明信
2015/06/19 职场文书
母亲去世追悼词
2015/06/23 职场文书
《狮子和鹿》教学反思
2016/02/16 职场文书
数据库连接池
2021/04/06 MySQL
HTML+CSS实现导航条下拉菜单的示例代码
2021/08/02 HTML / CSS