Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中常用的各种数据库操作模块和连接实例
May 29 Python
解析Python编程中的包结构
Oct 25 Python
Python解决N阶台阶走法问题的方法分析
Dec 28 Python
Python中的并发处理之asyncio包使用的详解
Apr 03 Python
将pandas.dataframe的数据写入到文件中的方法
Dec 07 Python
Python数据集切分实例
Dec 08 Python
pytorch中如何使用DataLoader对数据集进行批处理的方法
Aug 06 Python
pygame实现贪吃蛇游戏(下)
Oct 29 Python
提升python处理速度原理及方法实例
Dec 25 Python
详解Python的三种拷贝方式
Feb 11 Python
如何让python的运行速度得到提升
Jul 08 Python
python 将列表里的字典元素合并为一个字典实例
Sep 01 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
在Zeus Web Server中安装PHP语言支持
2006/10/09 PHP
php中Smarty模板初体验
2011/08/08 PHP
PHP动态生成指定大小随机图片的方法
2016/03/25 PHP
Laravel Intervention/image图片处理扩展包的安装、使用与可能遇到的坑详解
2017/11/14 PHP
php+ajax实现文件切割上传功能示例
2020/03/03 PHP
JQuery与JSon实现的无刷新分页代码
2011/09/13 Javascript
javascript针对DOM的应用分析(三)
2012/04/15 Javascript
21个值得收藏的Javascript技巧
2014/02/04 Javascript
jQuery构造函数init参数分析续
2015/05/13 Javascript
jQuery控制DIV层实现由大到小,由远及近动画变化效果
2015/10/09 Javascript
JavaScript实现的背景自动变色代码
2015/10/17 Javascript
jquery拼接ajax 的json和字符串拼接的方法
2017/03/11 Javascript
微信小程序表单验证功能完整实例
2017/12/01 Javascript
使用Vue.observable()进行状态管理的实例代码详解
2019/05/26 Javascript
Websocket 向指定用户发消息的方法
2020/01/09 Javascript
JS+CSS实现炫酷光感效果
2020/09/05 Javascript
vue实现点击出现操作弹出框的示例
2020/11/05 Javascript
在Python3中初学者应会的一些基本的提升效率的小技巧
2015/03/31 Python
Python图像灰度变换及图像数组操作
2016/01/27 Python
python3实现读取chrome浏览器cookie
2016/06/19 Python
Python正确重载运算符的方法示例详解
2017/08/27 Python
详解appium+python 启动一个app步骤
2017/12/20 Python
对numpy中array和asarray的区别详解
2018/04/17 Python
Python字符串匹配之6种方法的使用详解
2019/04/08 Python
Tensorflow中tf.ConfigProto()的用法详解
2020/02/06 Python
Pycharm和Idea支持的vim插件的方法
2020/02/21 Python
Python多线程正确用法实例解析
2020/05/30 Python
分厂厂长岗位职责
2013/12/29 职场文书
优秀实习生感言
2014/03/01 职场文书
暑期学习心得体会
2014/09/02 职场文书
拔河比赛队名及霸气口号
2015/12/24 职场文书
党组织关系的介绍信模板
2019/06/21 职场文书
浅谈python中的多态
2021/06/15 Python
使用Ajax实现无刷新上传文件
2022/04/12 Javascript
python和anaconda的区别
2022/05/06 Python
GO中sync包自由控制并发示例详解
2022/08/05 Golang