tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python从网络读取图片并直接进行处理的方法
May 22 Python
Python实现数通设备端口使用情况监控实例
Jul 15 Python
Python实现删除当前目录下除当前脚本以外的文件和文件夹实例
Jul 27 Python
django 捕获异常和日志系统过程详解
Jul 18 Python
django一对多模型以及如何在前端实现详解
Jul 24 Python
对django中foreignkey的简单使用详解
Jul 28 Python
Python Django 实现简单注册功能过程详解
Jul 29 Python
Python二元赋值实用技巧解析
Oct 25 Python
python 实现list或string按指定分段
Dec 25 Python
Python如何使用字符打印照片
Jan 03 Python
pytorch-神经网络拟合曲线实例
Jan 15 Python
Python自动化测试中yaml文件读取操作
Aug 20 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
php通过前序遍历树实现无需递归的无限极分类
2015/07/10 PHP
PHP的Yii框架中Model模型的学习教程
2016/03/29 PHP
通过Unicode转义序列来加密,按你说的可以算是混淆吧
2007/05/06 Javascript
msn上的tab功能Firefox对childNodes处理的一个BUG
2008/01/21 Javascript
6款新颖的jQuery和CSS3进度条插件推荐
2013/03/05 Javascript
判断输入是否为空,获得输入类型的JS代码
2013/10/30 Javascript
javascript修改图片src的方法
2015/01/27 Javascript
简述JavaScript对传统文档对象模型的支持
2015/06/16 Javascript
Javascript动画效果(1)
2016/10/11 Javascript
教你快速搭建Node.Js服务器的方法教程
2017/03/30 Javascript
详解webpack解惑:require的五种用法
2017/06/09 Javascript
详解Vue组件之作用域插槽
2018/11/22 Javascript
浅谈Vue的响应式原理
2019/05/30 Javascript
在小程序中推送模板消息的实现方法
2019/07/22 Javascript
JQuery实现折叠式菜单的详细代码
2020/06/03 jQuery
python制作小说爬虫实录
2017/08/14 Python
Python实现识别手写数字大纲
2018/01/29 Python
利用python GDAL库读写geotiff格式的遥感影像方法
2018/11/29 Python
基于PyQt4和PySide实现输入对话框效果
2019/02/27 Python
Python使用POP3和SMTP协议收发邮件的示例代码
2019/04/16 Python
Python实现我的世界小游戏源代码
2021/03/02 Python
HTML5 Plus 实现手机APP拍照或相册选择图片上传功能
2016/07/13 HTML / CSS
使用HTML5里的classList操作CSS类
2016/06/28 HTML / CSS
小狗电器官方商城:中国高端吸尘器品牌
2017/03/29 全球购物
Styleonme中文网:韩国高档人气品牌
2017/06/21 全球购物
项目管理计划书
2014/01/09 职场文书
授权委托书样本及填写说明
2014/09/19 职场文书
三严三实心得体会范文
2014/10/13 职场文书
节约用电倡议书
2015/04/28 职场文书
2015年大学学生会工作总结
2015/05/13 职场文书
刑事附带民事上诉状
2015/05/23 职场文书
禁毒主题班会教案
2015/08/14 职场文书
告诉你一个秘密:富人致富的五大优点
2019/07/11 职场文书
详解Python魔法方法之描述符类
2021/05/26 Python
Redis集群节点通信过程/原理流程分析
2022/03/18 Redis
SpringCloud超详细讲解Feign声明式服务调用
2022/06/21 Java/Android