python tensorflow基于cnn实现手写数字识别


Posted in Python onJanuary 01, 2018

一份基于cnn的手写数字自识别的代码,供大家参考,具体内容如下

# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 加载数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# 以交互式方式启动session
# 如果不使用交互式session,则在启动session前必须
# 构建整个计算图,才能启动该计算图
sess = tf.InteractiveSession()

"""构建计算图"""
# 通过占位符来为输入图像和目标输出类别创建节点
# shape参数是可选的,有了它tensorflow可以自动捕获维度不一致导致的错误
x = tf.placeholder("float", shape=[None, 784]) # 原始输入
y_ = tf.placeholder("float", shape=[None, 10]) # 目标值

# 为了不在建立模型的时候反复做初始化操作,
# 我们定义两个函数用于初始化
def weight_variable(shape):
 # 截尾正态分布,stddev是正态分布的标准偏差
 initial = tf.truncated_normal(shape=shape, stddev=0.1)
 return tf.Variable(initial)
def bias_variable(shape):
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

# 卷积核池化,步长为1,0边距
def conv2d(x, W):
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
       strides=[1, 2, 2, 1], padding='SAME')

"""第一层卷积"""
# 由一个卷积和一个最大池化组成。滤波器5x5中算出32个特征,是因为使用32个滤波器进行卷积
# 卷积的权重张量形状是[5, 5, 1, 32],1是输入通道的个数,32是输出通道个数
W_conv1 = weight_variable([5, 5, 1, 32])
# 每一个输出通道都有一个偏置量
b_conv1 = bias_variable([32])

# 位了使用卷积,必须将输入转换成4维向量,2、3维表示图片的宽、高
# 最后一维表示图片的颜色通道(因为是灰度图像所以通道数维1,RGB图像通道数为3)
x_image = tf.reshape(x, [-1, 28, 28, 1])

# 第一层的卷积结果,使用Relu作为激活函数
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1))
# 第一层卷积后的池化结果
h_pool1 = max_pool_2x2(h_conv1)

"""第二层卷积"""
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

"""全连接层"""
# 图片尺寸减小到7*7,加入一个有1024个神经元的全连接层
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# 将最后的池化层输出张量reshape成一维向量
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# 全连接层的输出
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

"""使用Dropout减少过拟合"""
# 使用placeholder占位符来表示神经元的输出在dropout中保持不变的概率
# 在训练的过程中启用dropout,在测试过程中关闭dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

"""输出层"""
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# 模型预测输出
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

# 交叉熵损失
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))

# 模型训练,使用AdamOptimizer来做梯度最速下降
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# 正确预测,得到True或False的List
correct_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
# 将布尔值转化成浮点数,取平均值作为精确度
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# 在session中先初始化变量才能在session中调用
sess.run(tf.global_variables_initializer())

# 迭代优化模型
for i in range(2000):
 # 每次取50个样本进行训练
 batch = mnist.train.next_batch(50)
 if i%100 == 0:
  train_accuracy = accuracy.eval(feed_dict={
   x: batch[0], y_: batch[1], keep_prob: 1.0}) # 模型中间不使用dropout
  print("step %d, training accuracy %g" % (i, train_accuracy))
 train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob: 0.5})
print("test accuracy %g" % accuracy.eval(feed_dict={
   x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

做了2000次迭代,在测试集上的识别精度能够到0.9772……

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python构建Hopfield网络的教程
Apr 14 Python
Python中random模块用法实例分析
May 19 Python
python套接字流重定向实例汇总
Mar 03 Python
Python利用字典将两个通讯录文本合并为一个文本实例
Jan 16 Python
Python爬虫抓取代理IP并检验可用性的实例
May 07 Python
python实现对任意大小图片均匀切割的示例
Dec 05 Python
浅谈Python批处理文件夹中的txt文件
Mar 11 Python
感知器基础原理及python实现过程详解
Sep 30 Python
Python笔记之观察者模式
Nov 20 Python
Python 序列化和反序列化库 MarshMallow 的用法实例代码
Feb 25 Python
解决pytorch 数据类型报错的问题
Mar 03 Python
浅谈tf.train.Saver()与tf.train.import_meta_graph的要点
May 26 Python
python+selenium实现163邮箱自动登陆的方法
Dec 31 #Python
python 类对象和实例对象动态添加方法(分享)
Dec 31 #Python
利用python将图片转换成excel文档格式
Dec 30 #Python
书单|人生苦短,你还不用python!
Dec 29 #Python
python ansible服务及剧本编写
Dec 29 #Python
详解python 拆包可迭代数据如tuple, list
Dec 29 #Python
详解Python异常处理中的Finally else的功能
Dec 29 #Python
You might like
织梦模板标记简介
2007/03/11 PHP
数据库查询记录php 多行多列显示
2009/08/15 PHP
libmysql.dll与php.ini是否真的要拷贝到c:\windows目录下呢
2010/03/15 PHP
php常用的url处理函数总结
2014/11/19 PHP
十幅图告诉你什么是PHP引用
2015/02/22 PHP
JavaScript 私有成员分析
2009/01/13 Javascript
IE6/7/8中Option元素未设value时Select将获取空字符串
2011/04/07 Javascript
动态加载外部javascript文件的函数代码分享
2011/07/28 Javascript
js 手机号码合法性验证代码集合
2012/09/29 Javascript
javascript窗口宽高,鼠标位置,滚动高度(详细解析)
2013/11/18 Javascript
JS 新增Cookie 取cookie值 删除cookie 举例详解
2014/10/10 Javascript
浅谈jQuery animate easing的具体使用方法(推荐)
2016/06/17 Javascript
微信+angularJS的SPA应用中用router进行页面跳转,jssdk校验失败问题解决
2016/09/09 Javascript
AngularJS通过ng-Img-Crop实现头像截取的示例
2017/08/17 Javascript
web前端开发中常见的多列布局解决方案整理(一定要看)
2017/10/15 Javascript
vue将时间戳转换成自定义时间格式的方法
2018/03/02 Javascript
深入浅析Node.js 事件循环、定时器和process.nextTick()
2018/10/22 Javascript
vue组件从开发到发布的实现步骤
2018/11/11 Javascript
jQuery+ajax实现批量删除功能完整示例
2019/06/06 jQuery
通过实例解析json与jsonp原理及使用方法
2020/09/27 Javascript
Google开源的Python格式化工具YAPF的安装和使用教程
2016/05/31 Python
Python实现生成密码字典的方法示例
2019/09/02 Python
浅谈Pycharm最有必要改的几个默认设置项
2020/02/14 Python
Python编程快速上手——strip()函数的正则表达式实现方法分析
2020/02/29 Python
如何使用python-opencv批量生成带噪点噪线的数字验证码
2020/12/21 Python
python元组拆包实现方法
2021/02/28 Python
Lookfantastic澳大利亚官网:英国知名美妆购物网站
2021/01/07 全球购物
介绍一下linux的文件权限
2014/07/20 面试题
财务会计专业毕业生自荐信
2013/10/19 职场文书
大专应届生个人的自我评价
2013/11/21 职场文书
个性发展自我评价
2014/02/11 职场文书
大三学生做职业规划:给未来找个方向
2014/02/24 职场文书
生物学专业求职信
2014/07/23 职场文书
欢迎新生标语
2014/10/06 职场文书
2015年小学实验室工作总结
2015/07/28 职场文书
人生感悟经典句子
2019/08/20 职场文书