Tensorflow训练MNIST手写数字识别模型


Posted in Python onFebruary 13, 2020

本文实例为大家分享了Tensorflow训练MNIST手写数字识别模型的具体代码,供大家参考,具体内容如下

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
 
INPUT_NODE = 784  # 输入层节点=图片像素=28x28=784
OUTPUT_NODE = 10  # 输出层节点数=图片类别数目
 
LAYER1_NODE = 500  # 隐藏层节点数,只有一个隐藏层
BATCH_SIZE = 100  # 一个训练包中的数据个数,数字越小
          # 越接近随机梯度下降,越大越接近梯度下降
 
LEARNING_RATE_BASE = 0.8   # 基础学习率
LEARNING_RATE_DECAY = 0.99  # 学习率衰减率
 
REGULARIZATION_RATE = 0.0001  # 正则化项系数
TRAINING_STEPS = 30000     # 训练轮数
MOVING_AVG_DECAY = 0.99    # 滑动平均衰减率
 
# 定义一个辅助函数,给定神经网络的输入和所有参数,计算神经网络的前向传播结果
def inference(input_tensor, avg_class, weights1, biases1,
       weights2, biases2):
 
 # 当没有提供滑动平均类时,直接使用参数当前取值
 if avg_class == None:
  # 计算隐藏层前向传播结果
  layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
  # 计算输出层前向传播结果
  return tf.matmul(layer1, weights2) + biases2
 else:
  # 首先计算变量的滑动平均值,然后计算前向传播结果
  layer1 = tf.nn.relu(
    tf.matmul(input_tensor, avg_class.average(weights1)) +
    avg_class.average(biases1))
  
  return tf.matmul(
    layer1, avg_class.average(weights2)) + avg_class.average(biases2)
 
# 训练模型的过程
def train(mnist):
 x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
 y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
 
 # 生成隐藏层参数
 weights1 = tf.Variable(
   tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
 biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
 
 # 生成输出层参数
 weights2 = tf.Variable(
   tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
 biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))
 
 # 计算前向传播结果,不使用参数滑动平均值 avg_class=None
 y = inference(x, None, weights1, biases1, weights2, biases2)
 
 # 定义训练轮数变量,指定为不可训练
 global_step = tf.Variable(0, trainable=False)
 
 # 给定滑动平均衰减率和训练轮数的变量,初始化滑动平均类
 variable_avgs = tf.train.ExponentialMovingAverage(
   MOVING_AVG_DECAY, global_step)
 
 # 在所有代表神经网络参数的可训练变量上使用滑动平均
 variables_avgs_op = variable_avgs.apply(tf.trainable_variables())
 
 # 计算使用滑动平均值后的前向传播结果
 avg_y = inference(x, variable_avgs, weights1, biases1, weights2, biases2)
 
 # 计算交叉熵作为损失函数
 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
   logits=y, labels=tf.argmax(y_, 1))
 cross_entropy_mean = tf.reduce_mean(cross_entropy)
 
 # 计算L2正则化损失函数
 regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
 regularization = regularizer(weights1) + regularizer(weights2)
 
 loss = cross_entropy_mean + regularization
 
 # 设置指数衰减的学习率
 learning_rate = tf.train.exponential_decay(
   LEARNING_RATE_BASE,
   global_step,              # 当前迭代轮数
   mnist.train.num_examples / BATCH_SIZE, # 过完所有训练数据的迭代次数
   LEARNING_RATE_DECAY)
 
 
 # 优化损失函数
 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
   loss, global_step=global_step)
 
 # 反向传播同时更新神经网络参数及其滑动平均值
 with tf.control_dependencies([train_step, variables_avgs_op]):
  train_op = tf.no_op(name='train')
 
 # 检验使用了滑动平均模型的神经网络前向传播结果是否正确
 correct_prediction = tf.equal(tf.argmax(avg_y, 1), tf.argmax(y_, 1))
 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
 
 # 初始化会话并开始训练
 with tf.Session() as sess:
  tf.global_variables_initializer().run()
  
  # 准备验证数据,用于判断停止条件和训练效果
  validate_feed = {x: mnist.validation.images,
          y_: mnist.validation.labels}
  
  # 准备测试数据,用于模型优劣的最后评价标准
  test_feed = {x: mnist.test.images, y_: mnist.test.labels}
  
  # 迭代训练神经网络
  for i in range(TRAINING_STEPS):
   if i%1000 == 0:
    validate_acc = sess.run(accuracy, feed_dict=validate_feed)
    print("After %d training step(s), validation accuracy using average " 
       "model is %g " % (i, validate_acc))
    
   xs, ys = mnist.train.next_batch(BATCH_SIZE)
   sess.run(train_op, feed_dict={x: xs, y_: ys})
  
  # 训练结束后在测试集上检测模型的最终正确率
  test_acc = sess.run(accuracy, feed_dict=test_feed)
  print("After %d training steps, test accuracy using average model "
     "is %g " % (TRAINING_STEPS, test_acc))
  
  
# 主程序入口
def main(argv=None):
 mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
 train(mnist)
 
# Tensorflow主程序入口
if __name__ == '__main__':
 tf.app.run()

输出结果如下:

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 0 training step(s), validation accuracy using average model is 0.0462 
After 1000 training step(s), validation accuracy using average model is 0.9784 
After 2000 training step(s), validation accuracy using average model is 0.9806 
After 3000 training step(s), validation accuracy using average model is 0.9798 
After 4000 training step(s), validation accuracy using average model is 0.9814 
After 5000 training step(s), validation accuracy using average model is 0.9826 
After 6000 training step(s), validation accuracy using average model is 0.9828 
After 7000 training step(s), validation accuracy using average model is 0.9832 
After 8000 training step(s), validation accuracy using average model is 0.9838 
After 9000 training step(s), validation accuracy using average model is 0.983 
After 10000 training step(s), validation accuracy using average model is 0.9836 
After 11000 training step(s), validation accuracy using average model is 0.9822 
After 12000 training step(s), validation accuracy using average model is 0.983 
After 13000 training step(s), validation accuracy using average model is 0.983 
After 14000 training step(s), validation accuracy using average model is 0.9844 
After 15000 training step(s), validation accuracy using average model is 0.9832 
After 16000 training step(s), validation accuracy using average model is 0.9844 
After 17000 training step(s), validation accuracy using average model is 0.9842 
After 18000 training step(s), validation accuracy using average model is 0.9842 
After 19000 training step(s), validation accuracy using average model is 0.9838 
After 20000 training step(s), validation accuracy using average model is 0.9834 
After 21000 training step(s), validation accuracy using average model is 0.9828 
After 22000 training step(s), validation accuracy using average model is 0.9834 
After 23000 training step(s), validation accuracy using average model is 0.9844 
After 24000 training step(s), validation accuracy using average model is 0.9838 
After 25000 training step(s), validation accuracy using average model is 0.9834 
After 26000 training step(s), validation accuracy using average model is 0.984 
After 27000 training step(s), validation accuracy using average model is 0.984 
After 28000 training step(s), validation accuracy using average model is 0.9836 
After 29000 training step(s), validation accuracy using average model is 0.9842 
After 30000 training steps, test accuracy using average model is 0.9839

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python写日志封装类实例
Jun 28 Python
详解Django中的过滤器
Jul 16 Python
Python之Web框架Django项目搭建全过程
May 02 Python
Python数据结构与算法之图的广度优先与深度优先搜索算法示例
Dec 14 Python
Windows 7下Python Web环境搭建图文教程
Mar 20 Python
Python实现中一次读取多个值的方法
Apr 22 Python
python更改已存在excel文件的方法
May 03 Python
python查看列的唯一值方法
Jul 17 Python
python中单例常用的几种实现方法总结
Oct 13 Python
如何在python中写hive脚本
Nov 08 Python
python wav模块获取采样率 采样点声道量化位数(实例代码)
Jan 22 Python
Python3利用scapy局域网实现自动多线程arp扫描功能
Jan 21 Python
Python3 读取Word文件方式
Feb 13 #Python
解决Python import docx出错DLL load failed的问题
Feb 13 #Python
python求最大公约数和最小公倍数的简单方法
Feb 13 #Python
python圣诞树编写实例详解
Feb 13 #Python
python如何实现复制目录到指定目录
Feb 13 #Python
Python制作简易版小工具之计算天数的实现思路
Feb 13 #Python
解决python-docx打包之后找不到default.docx的问题
Feb 13 #Python
You might like
Thinkphp的volist标签嵌套循环使用教程
2014/07/08 PHP
php专用数组排序类ArraySortUtil用法实例
2015/04/03 PHP
PHP扩展开发教程(总结)
2015/11/04 PHP
浅谈PHP中pack、unpack的详细用法
2018/03/12 PHP
laravel实现一个上传图片的接口,并建立软链接,访问图片的方法
2019/10/12 PHP
Yii 框架使用Forms操作详解
2020/05/18 PHP
javascript xml为数据源的下拉框控件
2009/07/07 Javascript
用Jquery实现可编辑表格并用AJAX提交到服务器修改数据
2009/12/27 Javascript
jQuery.lazyload+masonry改良图片瀑布流代码
2014/06/20 Javascript
详解AngularJS中自定义过滤器
2015/12/28 Javascript
javascript实现一个简单的弹出窗
2016/02/22 Javascript
Angular页面间切换及传值的4种方法
2016/11/04 Javascript
vue双向数据绑定原理探究(附demo)
2017/01/17 Javascript
jQuery实现复制到粘贴板功能
2017/02/11 Javascript
js中Object.defineProperty()方法的不详解
2018/07/09 Javascript
javascript关于“时间”的一次探索
2019/07/24 Javascript
解决vue的touchStart事件及click事件冲突问题
2020/07/21 Javascript
Vue+Element-U实现分页显示效果
2020/11/15 Javascript
[53:21]2014 DOTA2国际邀请赛中国区预选赛5.21 DT VS LGD-CDEC
2014/05/22 DOTA
Python迭代用法实例教程
2014/09/08 Python
python常见数制转换实例分析
2015/05/09 Python
Python删除Java源文件中全部注释的实现方法
2017/08/30 Python
Python设计模式之中介模式简单示例
2018/01/09 Python
解决pandas .to_excel不覆盖已有sheet的问题
2018/12/10 Python
python 实现矩阵上下/左右翻转,转置的示例
2019/01/23 Python
pyQt5实时刷新界面的示例
2019/06/25 Python
sqlalchemy实现时间列自动更新教程
2020/09/02 Python
Lands’ End官网:经典的美国生活方式品牌
2016/08/14 全球购物
材料采购员岗位职责
2013/12/17 职场文书
入党积极分子介绍信
2014/01/17 职场文书
关于责任的演讲稿
2014/05/20 职场文书
销售口号大全
2014/06/11 职场文书
学校政风行风评议工作总结
2014/10/21 职场文书
如何写贫困证明申请书
2014/10/29 职场文书
宪法宣传标语100条
2019/10/15 职场文书
五年级作文之劳动作文
2019/11/12 职场文书