基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python判断字符串与大小写转换
Jun 08 Python
使用Python中的tkinter模块作图的方法
Feb 07 Python
使用Python操作excel文件的实例代码
Oct 15 Python
django之session与分页(实例讲解)
Nov 13 Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 Python
python pygame模块编写飞机大战
Nov 20 Python
python实现对任意大小图片均匀切割的示例
Dec 05 Python
python网络编程socket实现服务端、客户端操作详解
Mar 24 Python
Python使用graphviz画流程图过程解析
Mar 31 Python
浅谈keras通过model.fit_generator训练模型(节省内存)
Jun 17 Python
pytorch 计算Parameter和FLOP的操作
Mar 04 Python
pytorch中的torch.nn.Conv2d()函数图文详解
Feb 28 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
PHP得到某段时间区间的时间戳 php定时任务
2012/04/12 PHP
php判断正常访问和外部访问的示例
2014/02/10 PHP
学习php设计模式 php实现抽象工厂模式
2015/12/07 PHP
form自动提交实例讲解
2017/07/10 PHP
thinkphp ajaxfileupload实现异步上传图片的示例
2017/08/28 PHP
php+lottery.js实现九宫格抽奖功能
2019/07/21 PHP
javascript 数组的方法集合
2008/06/05 Javascript
Discuz! 6.1_jQuery兼容问题
2008/09/23 Javascript
Javascript 命名空间模式
2013/11/01 Javascript
JQuery实现展开关闭层的方法
2015/02/17 Javascript
jQuery插件之jQuery.Form.js用法实例分析(附demo示例源码)
2016/01/04 Javascript
JavaScript表单验证实例之验证表单项是否为空
2016/01/10 Javascript
javascript断点调试心得分享
2016/04/23 Javascript
Bootstrap 组件之按钮(二)
2016/05/11 Javascript
jQuery控制控件文本的长度的操作方法
2016/12/05 Javascript
Angularjs根据json文件动态生成路由状态的实现方法
2017/04/17 Javascript
jquery加载单文件vue组件的方法
2017/06/20 jQuery
基于JavaScript实现无缝滚动效果
2017/07/21 Javascript
vue引入ueditor及node后台配置详解
2018/01/03 Javascript
基于vue v-for 循环复选框-默认勾选第一个的实现方法
2018/03/03 Javascript
JS伪继承prototype实现方法示例
2018/06/20 Javascript
layer插件select选中默认值的方法
2018/08/14 Javascript
基于Three.js实现360度全景图片
2018/12/30 Javascript
vue vantUI tab切换时 list组件不触发load事件的问题及解决方法
2020/02/14 Javascript
vue $mount 和 el的区别说明
2020/09/11 Javascript
Python基础教程之内置函数locals()和globals()用法分析
2018/03/16 Python
python爬虫之自动登录与验证码识别
2020/06/15 Python
Python面向对象基础入门之编码细节与注意事项
2018/12/11 Python
对python实现模板生成脚本的方法详解
2019/01/30 Python
如何基于Python制作有道翻译小工具
2019/12/16 Python
Python plt 利用subplot 实现在一张画布同时画多张图
2021/02/26 Python
一款CSS3实现多功能下拉菜单(带分享按)的教程
2014/11/05 HTML / CSS
美国现代家具和家居商店:Apt2B
2016/08/29 全球购物
爱护公物主题班会
2015/08/17 职场文书
高中物理教学反思
2016/02/19 职场文书
开发者首先否认《遗弃》被取消的传言
2022/04/11 其他游戏