Tensorflow训练MNIST手写数字识别模型


Posted in Python onFebruary 13, 2020

本文实例为大家分享了Tensorflow训练MNIST手写数字识别模型的具体代码,供大家参考,具体内容如下

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
 
INPUT_NODE = 784  # 输入层节点=图片像素=28x28=784
OUTPUT_NODE = 10  # 输出层节点数=图片类别数目
 
LAYER1_NODE = 500  # 隐藏层节点数,只有一个隐藏层
BATCH_SIZE = 100  # 一个训练包中的数据个数,数字越小
          # 越接近随机梯度下降,越大越接近梯度下降
 
LEARNING_RATE_BASE = 0.8   # 基础学习率
LEARNING_RATE_DECAY = 0.99  # 学习率衰减率
 
REGULARIZATION_RATE = 0.0001  # 正则化项系数
TRAINING_STEPS = 30000     # 训练轮数
MOVING_AVG_DECAY = 0.99    # 滑动平均衰减率
 
# 定义一个辅助函数,给定神经网络的输入和所有参数,计算神经网络的前向传播结果
def inference(input_tensor, avg_class, weights1, biases1,
       weights2, biases2):
 
 # 当没有提供滑动平均类时,直接使用参数当前取值
 if avg_class == None:
  # 计算隐藏层前向传播结果
  layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
  # 计算输出层前向传播结果
  return tf.matmul(layer1, weights2) + biases2
 else:
  # 首先计算变量的滑动平均值,然后计算前向传播结果
  layer1 = tf.nn.relu(
    tf.matmul(input_tensor, avg_class.average(weights1)) +
    avg_class.average(biases1))
  
  return tf.matmul(
    layer1, avg_class.average(weights2)) + avg_class.average(biases2)
 
# 训练模型的过程
def train(mnist):
 x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
 y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
 
 # 生成隐藏层参数
 weights1 = tf.Variable(
   tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
 biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
 
 # 生成输出层参数
 weights2 = tf.Variable(
   tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
 biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))
 
 # 计算前向传播结果,不使用参数滑动平均值 avg_class=None
 y = inference(x, None, weights1, biases1, weights2, biases2)
 
 # 定义训练轮数变量,指定为不可训练
 global_step = tf.Variable(0, trainable=False)
 
 # 给定滑动平均衰减率和训练轮数的变量,初始化滑动平均类
 variable_avgs = tf.train.ExponentialMovingAverage(
   MOVING_AVG_DECAY, global_step)
 
 # 在所有代表神经网络参数的可训练变量上使用滑动平均
 variables_avgs_op = variable_avgs.apply(tf.trainable_variables())
 
 # 计算使用滑动平均值后的前向传播结果
 avg_y = inference(x, variable_avgs, weights1, biases1, weights2, biases2)
 
 # 计算交叉熵作为损失函数
 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
   logits=y, labels=tf.argmax(y_, 1))
 cross_entropy_mean = tf.reduce_mean(cross_entropy)
 
 # 计算L2正则化损失函数
 regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
 regularization = regularizer(weights1) + regularizer(weights2)
 
 loss = cross_entropy_mean + regularization
 
 # 设置指数衰减的学习率
 learning_rate = tf.train.exponential_decay(
   LEARNING_RATE_BASE,
   global_step,              # 当前迭代轮数
   mnist.train.num_examples / BATCH_SIZE, # 过完所有训练数据的迭代次数
   LEARNING_RATE_DECAY)
 
 
 # 优化损失函数
 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
   loss, global_step=global_step)
 
 # 反向传播同时更新神经网络参数及其滑动平均值
 with tf.control_dependencies([train_step, variables_avgs_op]):
  train_op = tf.no_op(name='train')
 
 # 检验使用了滑动平均模型的神经网络前向传播结果是否正确
 correct_prediction = tf.equal(tf.argmax(avg_y, 1), tf.argmax(y_, 1))
 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
 
 # 初始化会话并开始训练
 with tf.Session() as sess:
  tf.global_variables_initializer().run()
  
  # 准备验证数据,用于判断停止条件和训练效果
  validate_feed = {x: mnist.validation.images,
          y_: mnist.validation.labels}
  
  # 准备测试数据,用于模型优劣的最后评价标准
  test_feed = {x: mnist.test.images, y_: mnist.test.labels}
  
  # 迭代训练神经网络
  for i in range(TRAINING_STEPS):
   if i%1000 == 0:
    validate_acc = sess.run(accuracy, feed_dict=validate_feed)
    print("After %d training step(s), validation accuracy using average " 
       "model is %g " % (i, validate_acc))
    
   xs, ys = mnist.train.next_batch(BATCH_SIZE)
   sess.run(train_op, feed_dict={x: xs, y_: ys})
  
  # 训练结束后在测试集上检测模型的最终正确率
  test_acc = sess.run(accuracy, feed_dict=test_feed)
  print("After %d training steps, test accuracy using average model "
     "is %g " % (TRAINING_STEPS, test_acc))
  
  
# 主程序入口
def main(argv=None):
 mnist = input_data.read_data_sets("/tmp/data", one_hot=True)
 train(mnist)
 
# Tensorflow主程序入口
if __name__ == '__main__':
 tf.app.run()

输出结果如下:

Extracting /tmp/data/train-images-idx3-ubyte.gz
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
After 0 training step(s), validation accuracy using average model is 0.0462 
After 1000 training step(s), validation accuracy using average model is 0.9784 
After 2000 training step(s), validation accuracy using average model is 0.9806 
After 3000 training step(s), validation accuracy using average model is 0.9798 
After 4000 training step(s), validation accuracy using average model is 0.9814 
After 5000 training step(s), validation accuracy using average model is 0.9826 
After 6000 training step(s), validation accuracy using average model is 0.9828 
After 7000 training step(s), validation accuracy using average model is 0.9832 
After 8000 training step(s), validation accuracy using average model is 0.9838 
After 9000 training step(s), validation accuracy using average model is 0.983 
After 10000 training step(s), validation accuracy using average model is 0.9836 
After 11000 training step(s), validation accuracy using average model is 0.9822 
After 12000 training step(s), validation accuracy using average model is 0.983 
After 13000 training step(s), validation accuracy using average model is 0.983 
After 14000 training step(s), validation accuracy using average model is 0.9844 
After 15000 training step(s), validation accuracy using average model is 0.9832 
After 16000 training step(s), validation accuracy using average model is 0.9844 
After 17000 training step(s), validation accuracy using average model is 0.9842 
After 18000 training step(s), validation accuracy using average model is 0.9842 
After 19000 training step(s), validation accuracy using average model is 0.9838 
After 20000 training step(s), validation accuracy using average model is 0.9834 
After 21000 training step(s), validation accuracy using average model is 0.9828 
After 22000 training step(s), validation accuracy using average model is 0.9834 
After 23000 training step(s), validation accuracy using average model is 0.9844 
After 24000 training step(s), validation accuracy using average model is 0.9838 
After 25000 training step(s), validation accuracy using average model is 0.9834 
After 26000 training step(s), validation accuracy using average model is 0.984 
After 27000 training step(s), validation accuracy using average model is 0.984 
After 28000 training step(s), validation accuracy using average model is 0.9836 
After 29000 training step(s), validation accuracy using average model is 0.9842 
After 30000 training steps, test accuracy using average model is 0.9839

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python对两个有序列表进行合并和排序的例子
Jun 13 Python
Python中DJANGO简单测试实例
May 11 Python
python获取文件扩展名的方法
Jul 06 Python
Python中生成Epoch的方法
Apr 26 Python
Python SVM(支持向量机)实现方法完整示例
Jun 19 Python
详解python while 函数及while和for的区别
Sep 07 Python
Python实现图片转字符画的代码实例
Feb 22 Python
Python设置matplotlib.plot的坐标轴刻度间隔以及刻度范围
Jun 25 Python
Python处理时间日期坐标轴过程详解
Jun 25 Python
Win10下Python3.7.3安装教程图解
Jul 08 Python
对django2.0 关联表的必填on_delete参数的含义解析
Aug 09 Python
使用python接受tgam的脑波数据实例
Apr 09 Python
Python3 读取Word文件方式
Feb 13 #Python
解决Python import docx出错DLL load failed的问题
Feb 13 #Python
python求最大公约数和最小公倍数的简单方法
Feb 13 #Python
python圣诞树编写实例详解
Feb 13 #Python
python如何实现复制目录到指定目录
Feb 13 #Python
Python制作简易版小工具之计算天数的实现思路
Feb 13 #Python
解决python-docx打包之后找不到default.docx的问题
Feb 13 #Python
You might like
笑谈配置,使用Smarty技术
2007/01/04 PHP
php二维数组排序方法(array_multisort usort)
2013/12/25 PHP
PHP return语句的另一个作用
2014/07/30 PHP
php不使用copy()函数复制文件的方法
2015/03/13 PHP
Ubuntu上安装yaf扩展的方法
2018/01/29 PHP
laravel5实现微信第三方登录功能
2018/12/06 PHP
js动态添加onload、onresize、onscroll事件(另类方法)
2012/12/26 Javascript
3分钟写出来的Jquery版checkbox全选反选功能
2013/10/23 Javascript
table行随鼠标移动变色示例
2014/05/07 Javascript
js分页工具实例
2015/01/28 Javascript
jquery实现鼠标经过显示下划线的渐变下拉菜单效果代码
2015/08/24 Javascript
基于BootStrap Metronic开发框架经验小结【八】框架功能总体界面介绍
2016/05/12 Javascript
jquery遍历table的tr获取td的值实现方法
2016/05/19 Javascript
JS常见简单正则表达式验证功能小结【手机,地址,企业税号,金额,身份证等】
2017/01/22 Javascript
JavaScript组件开发之输入框加候选框
2017/03/10 Javascript
简单理解Vue中的nextTick方法
2018/01/30 Javascript
Layui表格行工具事件与数据回填方法
2019/09/13 Javascript
在vue中使用jsx语法的使用方法
2019/09/30 Javascript
vue 父组件通过$refs获取子组件的值和方法详解
2019/11/07 Javascript
微信小程序实现点击页面出现文字
2020/09/21 Javascript
node.js通过url读取文件
2020/10/16 Javascript
[42:25]2018DOTA2亚洲邀请赛 4.5 淘汰赛 LGD vs Liquid 第三场
2018/04/06 DOTA
web.py中调用文件夹内模板的方法
2014/08/26 Python
Python中使用scapy模拟数据包实现arp攻击、dns放大攻击例子
2014/10/23 Python
Python数据分析:手把手教你用Pandas生成可视化图表的教程
2018/12/15 Python
python 模拟贷款卡号生成规则过程解析
2019/08/30 Python
Python基本语法之运算符功能与用法详解
2019/10/22 Python
真正了解CSS3背景下的@font face规则
2017/05/04 HTML / CSS
在HTML5 canvas里用卷积核进行图像处理的方法
2018/05/02 HTML / CSS
英国最大的独立摄影零售商:Park Cameras
2019/11/27 全球购物
技术学校毕业生求职信分享
2013/12/02 职场文书
幼儿园大班评语大全
2014/04/17 职场文书
教师节标语大全
2014/10/07 职场文书
超市食品安全承诺书
2015/04/29 职场文书
喜迎建国70周年:有关爱国的名言名句
2019/09/24 职场文书
范文之农村基层党建工作报告
2019/10/24 职场文书