基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python list转dict示例分享
Jan 28 Python
给Python IDLE加上自动补全和历史功能
Nov 30 Python
Mac中Python 3环境下安装scrapy的方法教程
Oct 26 Python
TF-IDF与余弦相似性的应用(二) 找出相似文章
Dec 21 Python
python3中函数参数的四种简单用法
Jul 09 Python
Python txt文件加入字典并查询的方法
Jan 15 Python
django 自定义filter 判断if var in list的例子
Aug 20 Python
详解Python实现进度条的4种方式
Jan 15 Python
tensorflow查看ckpt各节点名称实例
Jan 21 Python
使用Django实现把两个模型类的数据聚合在一起
Mar 28 Python
python中rb含义理解
Jun 18 Python
Python 利用Entrez库筛选下载PubMed文献摘要的示例
Nov 24 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
php中的时间显示
2007/01/18 PHP
php+javascript的日历控件
2009/11/19 PHP
yii框架配置默认controller和action示例
2014/04/30 PHP
YII中assets的使用示例
2014/07/31 PHP
php实现mysql事务处理的方法
2014/12/25 PHP
ajax+php实现无刷新验证手机号的实例
2017/12/22 PHP
php实现构建排除当前元素的乘积数组方法
2018/10/06 PHP
PHP 加密 Password Hashing API基础知识点
2020/03/02 PHP
jQuery中ajax的使用与缓存问题的解决方法
2013/12/19 Javascript
jquery控制表单输入框显示默认值的方法
2015/05/22 Javascript
深入学习jQuery Validate表单验证
2016/01/18 Javascript
ajax图片上传,图片异步上传,更新实例
2016/12/30 Javascript
详解用webpack2.0构建vue2.0超详细精简版
2017/04/05 Javascript
jQuery Json数据格式排版高亮插件json-viewer.js使用方法详解
2017/06/12 jQuery
vue中实现methods一个方法调用另外一个方法
2018/02/08 Javascript
vue-cli 使用axios的操作方法及整合axios的多种方法
2018/09/12 Javascript
详解多页应用 Webpack4 配置优化与踩坑记录
2018/10/16 Javascript
微信小程序全局变量改变监听的实现方法
2019/07/15 Javascript
实例分析JS中的相等性判断===、 ==和Object.is()
2019/11/17 Javascript
使用vant的地域控件追加全部选项
2020/11/03 Javascript
[01:00:52]2018DOTA2亚洲邀请赛 4.4 淘汰赛 EG vs LGD 第一场
2018/04/05 DOTA
基于python实现的百度新歌榜、热歌榜下载器(附代码)
2019/08/05 Python
Python集合基本概念与相关操作实例分析
2019/10/30 Python
python实现秒杀商品的微信自动提醒功能(代码详解)
2020/04/27 Python
判断Threading.start新线程是否执行完毕的实例
2020/05/02 Python
实例讲解Python 迭代器与生成器
2020/07/08 Python
Django如何实现防止XSS攻击
2020/10/13 Python
Python使用eval函数执行动态标表达式过程详解
2020/10/17 Python
CSS3 二级导航菜单的制作的示例
2018/04/02 HTML / CSS
用HTML5.0制作网页的教程
2010/05/30 HTML / CSS
澳大利亚运动鞋零售商:The Athlete’s Foot
2018/11/04 全球购物
美国在线鞋类零售商:LifeStride
2019/06/09 全球购物
物业招聘计划书
2014/01/10 职场文书
2015年勤工助学工作总结
2015/04/29 职场文书
《合作意向书》怎么写?
2019/08/20 职场文书
MongoDB balancer的使用详解
2021/04/30 MongoDB