Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django框架中的对象列表视图使用示例
Jul 21 Python
Python实现数据库并行读取和写入实例
Jun 09 Python
Python实现感知器模型、两层神经网络
Dec 19 Python
Python编写Windows Service服务程序
Jan 04 Python
python3+dlib实现人脸识别和情绪分析
Apr 21 Python
Python定义一个跨越多行的字符串的多种方法小结
Jul 19 Python
Python初学者需要注意的事项小结(python2与python3)
Sep 26 Python
在python中利用GDAL对tif文件进行读写的方法
Nov 29 Python
wxPython色环电阻计算器
Nov 18 Python
Python基础之函数原理与应用实例详解
Jan 03 Python
详解用Python进行时间序列预测的7种方法
Mar 13 Python
keras的三种模型实现与区别说明
Jul 03 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
自己前几天写的无限分类类
2007/02/14 PHP
PHP中函数rand和mt_rand的区别比较
2012/12/26 PHP
PHP实现的连贯操作、链式操作实例
2014/07/08 PHP
php使用ereg验证文件上传的方法
2014/12/16 PHP
深入探究PHP的多进程编程方法
2015/08/18 PHP
JavaScript脚本语言在网页中的简单应用
2007/05/13 Javascript
Jquery实现遮罩层的方法
2015/06/08 Javascript
js操作XML文件的实现方法兼容IE与FireFox
2016/06/25 Javascript
AngularJs html compiler详解及示例代码
2016/09/01 Javascript
JS中cookie的使用及缺点讲解
2017/05/13 Javascript
JavaScript实现图片本地预览功能【不用上传至服务器】
2017/09/20 Javascript
微信小程序日历组件使用方法详解
2018/12/29 Javascript
koa-router路由参数和前端路由的结合详解
2019/05/19 Javascript
[27:02]2014 DOTA2国际邀请赛中国区预选赛 5 23 CIS VS LGD第三场
2014/05/24 DOTA
用Python实现通过哈希算法检测图片重复的教程
2015/04/02 Python
解决python 未发现数据源名称并且未指定默认驱动程序的问题
2018/12/07 Python
Python3enumrate和range对比及示例详解
2019/07/13 Python
Python获取一个用户名的组ID过程解析
2019/09/03 Python
python paramiko远程服务器终端操作过程解析
2019/12/14 Python
matplotlib设置颜色、标记、线条,让你的图像更加丰富(推荐)
2020/09/25 Python
几款Python编译器比较与推荐(小结)
2020/10/15 Python
浅析关于Keras的安装(pycharm)和初步理解
2020/10/23 Python
奥斯汀独木舟和皮划艇:Austin Canoe & Kayak
2018/05/22 全球购物
回馈慈善的设计师太阳镜:DIFF eyewear
2019/10/17 全球购物
文员个人的求职信范文
2013/09/26 职场文书
静心口服夜广告词
2014/03/20 职场文书
公司经理聘任书
2014/03/29 职场文书
大学生就业求职信
2014/06/12 职场文书
法人身份证明书
2014/10/08 职场文书
优秀团员自我评价
2015/03/10 职场文书
病假证明模板
2015/06/19 职场文书
机关干部作风整顿心得体会
2016/01/22 职场文书
数学复习课教学反思
2016/02/18 职场文书
《时代广场的蟋蟀》读后感:真挚友情,温暖世界!
2020/01/08 职场文书
为什么mysql字段要使用NOT NULL
2021/05/13 MySQL
Redis 限流器
2022/05/15 Redis