tensorflow使用CNN分析mnist手写体数字数据集


Posted in Python onJune 17, 2020

本文实例为大家分享了tensorflow使用CNN分析mnist手写体数字数据集,供大家参考,具体内容如下

import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#把上述trX和teX的形状变为[-1,28,28,1],-1表示不考虑输入图片的数量,28×28是图片的长和宽的像素数,
# 1是通道(channel)数量,因为MNIST的图片是黑白的,所以通道是1,如果是RGB彩色图像,通道是3。
trX = trX.reshape(-1, 28, 28, 1) # 28x28x1 input img
teX = teX.reshape(-1, 28, 28, 1) # 28x28x1 input img
 
X = tf.placeholder("float", [None, 28, 28, 1])
Y = tf.placeholder("float", [None, 10])
#初始化权重与定义网络结构。
# 这里,我们将要构建一个拥有3个卷积层和3个池化层,随后接1个全连接层和1个输出层的卷积神经网络
def init_weights(shape):
 return tf.Variable(tf.random_normal(shape, stddev=0.01))
 
w = init_weights([3, 3, 1, 32])   # patch大小为3×3,输入维度为1,输出维度为32
w2 = init_weights([3, 3, 32, 64])   # patch大小为3×3,输入维度为32,输出维度为64
w3 = init_weights([3, 3, 64, 128])   # patch大小为3×3,输入维度为64,输出维度为128
w4 = init_weights([128 * 4 * 4, 625])  # 全连接层,输入维度为 128 × 4 × 4,是上一层的输出数据又三维的转变成一维, 输出维度为625
w_o = init_weights([625, 10]) # 输出层,输入维度为 625, 输出维度为10,代表10类(labels)
# 神经网络模型的构建函数,传入以下参数
# X:输入数据
# w:每一层的权重
# p_keep_conv,p_keep_hidden:dropout要保留的神经元比例
 
def model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden):
 # 第一组卷积层及池化层,最后dropout一些神经元
 l1a = tf.nn.relu(tf.nn.conv2d(X, w, strides=[1, 1, 1, 1], padding='SAME'))
 # l1a shape=(?, 28, 28, 32)
 l1 = tf.nn.max_pool(l1a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l1 shape=(?, 14, 14, 32)
 l1 = tf.nn.dropout(l1, p_keep_conv)
 
 # 第二组卷积层及池化层,最后dropout一些神经元
 l2a = tf.nn.relu(tf.nn.conv2d(l1, w2, strides=[1, 1, 1, 1], padding='SAME'))
 # l2a shape=(?, 14, 14, 64)
 l2 = tf.nn.max_pool(l2a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l2 shape=(?, 7, 7, 64)
 l2 = tf.nn.dropout(l2, p_keep_conv)
 # 第三组卷积层及池化层,最后dropout一些神经元
 l3a = tf.nn.relu(tf.nn.conv2d(l2, w3, strides=[1, 1, 1, 1], padding='SAME'))
 # l3a shape=(?, 7, 7, 128)
 l3 = tf.nn.max_pool(l3a, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
 # l3 shape=(?, 4, 4, 128)
 l3 = tf.reshape(l3, [-1, w4.get_shape().as_list()[0]]) # reshape to (?, 2048)
 l3 = tf.nn.dropout(l3, p_keep_conv)
 # 全连接层,最后dropout一些神经元
 l4 = tf.nn.relu(tf.matmul(l3, w4))
 l4 = tf.nn.dropout(l4, p_keep_hidden)
 # 输出层
 pyx = tf.matmul(l4, w_o)
 return pyx #返回预测值
 
#我们定义dropout的占位符——keep_conv,它表示在一层中有多少比例的神经元被保留下来。生成网络模型,得到预测值
p_keep_conv = tf.placeholder("float")
p_keep_hidden = tf.placeholder("float")
py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden) #得到预测值
#定义损失函数,这里我们仍然采用tf.nn.softmax_cross_entropy_with_logits来比较预测值和真实值的差异,并做均值处理;
# 定义训练的操作(train_op),采用实现RMSProp算法的优化器tf.train.RMSPropOptimizer,学习率为0.001,衰减值为0.9,使损失最小;
# 定义预测的操作(predict_op)
cost = tf.reduce_mean(tf.nn. softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict_op = tf.argmax(py_x, 1)
#定义训练时的批次大小和评估时的批次大小
batch_size = 128
test_size = 256
#在一个会话中启动图,开始训练和评估
# Launch the graph in a session
with tf.Session() as sess:
 # you need to initialize all variables
 tf. global_variables_initializer().run()
 for i in range(100):
  training_batch = zip(range(0, len(trX), batch_size),
        range(batch_size, len(trX)+1, batch_size))
  for start, end in training_batch:
   sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
           p_keep_conv: 0.8, p_keep_hidden: 0.5})
 
  test_indices = np.arange(len(teX)) # Get A Test Batch
  np.random.shuffle(test_indices)
  test_indices = test_indices[0:test_size]
 
  print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
       sess.run(predict_op, feed_dict={X: teX[test_indices],
               p_keep_conv: 1.0,
               p_keep_hidden: 1.0})))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
八大排序算法的Python实现
Jan 28 Python
Python 爬虫模拟登陆知乎
Sep 23 Python
Python基于递归和非递归算法求两个数最大公约数、最小公倍数示例
May 21 Python
Python面向对象之类的定义与继承用法示例
Jan 14 Python
pytorch 在sequential中使用view来reshape的例子
Aug 20 Python
scrapy数据存储在mysql数据库的两种方式(同步和异步)
Feb 18 Python
windows python3安装Jupyter Notebooks教程
Apr 13 Python
Python接口测试结果集实现封装比较
May 01 Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 Python
python+pygame实现坦克大战小游戏的示例代码(可以自定义子弹速度)
Aug 11 Python
用python实现一个简单的验证码
Dec 09 Python
python 中[0]*2与0*2的区别说明
May 10 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
You might like
php基础知识:类与对象(4) 范围解析操作符(::)
2006/12/13 PHP
php Notice: Undefined index 错误提示解决方法
2010/08/29 PHP
PHP操作数组的一些函数整理介绍
2011/07/17 PHP
组合算法的PHP解答方法
2012/02/04 PHP
php仿QQ验证码的实例分析
2013/07/01 PHP
ThinkPHP3.1之D方法实例详解
2014/06/20 PHP
PHP SOCKET编程详解
2015/05/22 PHP
深入解析PHP的Laravel框架中的event事件操作
2016/03/21 PHP
php 输出缓冲 Output Control用法实例详解
2020/03/03 PHP
js最简单的拖拽效果实现代码
2010/09/24 Javascript
理解Javascript_10_对象模型
2010/10/16 Javascript
Jquery实现鼠标移上弹出提示框、移出消失思路及代码
2013/05/19 Javascript
jquery 面包屑导航 具体实现
2013/06/05 Javascript
jquery实现树形二级菜单实例代码
2013/11/20 Javascript
Javascript中3种实现继承的方法和代码实例
2014/08/12 Javascript
JavaScript基础语法、dom操作树及document对象
2014/12/02 Javascript
transport.js和jquery冲突问题的解决方法
2015/02/10 Javascript
关于backbone url请求中参数带有中文存入数据库是乱码的快速解决办法
2016/06/13 Javascript
jQuery Ajax传值到Servlet出现乱码问题的解决方法
2016/10/09 Javascript
AngularJS中如何使用echart插件示例详解
2016/10/26 Javascript
JavaScript实现经典排序算法之插入排序
2016/12/28 Javascript
Javascript中document.referrer隐藏来源的方法
2017/01/16 Javascript
ES6学习教程之对象的扩展详解
2017/05/02 Javascript
详解JavaScript数组过滤相同元素的5种方法
2017/05/23 Javascript
用node撸一个监测复联4开售短信提醒的实现代码
2019/04/10 Javascript
详解vue项目中调用百度地图API使用方法
2019/04/25 Javascript
原生JS实现无缝轮播图片
2020/06/24 Javascript
正确理解Python中if __name__ == '__main__'
2019/01/24 Python
从0开始的Python学习016异常
2019/04/08 Python
PyCharm License Activation激活码失效问题的解决方法(图文详解)
2020/03/12 Python
Python figure参数及subplot子图绘制代码
2020/04/18 Python
HTML5 Canvas 实现圆形进度条并显示数字百分比效果示例
2017/08/18 HTML / CSS
香港彩色隐形眼镜在线商店:Stunninglens(全球免费送货)
2019/05/10 全球购物
运动会稿件50字
2014/02/17 职场文书
普通大学毕业生自荐信范文
2014/02/23 职场文书
行政管理毕业生自荐信
2014/02/24 职场文书