基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python中将单词首字母大写的capitalize()方法
May 18 Python
Python 爬虫的工具列表大全
Jan 31 Python
在java中如何定义一个抽象属性示例详解
Aug 18 Python
使用Python抓取豆瓣影评数据的方法
Oct 17 Python
Django中reverse反转并且传递参数的方法
Aug 06 Python
django多种支付、并发订单处理实例代码
Dec 13 Python
PyCharm永久激活方式(推荐)
Sep 22 Python
python和c语言哪个更适合初学者
Jun 22 Python
Python实现PS滤镜中的USM锐化效果
Dec 04 Python
Python 求向量的余弦值操作
Mar 04 Python
tensorboard 可视化之localhost:6006不显示的解决方案
May 22 Python
pytorch 两个GPU同时训练的解决方案
Jun 01 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
GD输出汉字的函数的分析
2006/10/09 PHP
php性能优化分析工具XDebug 大型网站调试工具
2011/05/22 PHP
PHP 抽象方法与抽象类abstract关键字介绍及应用
2014/10/16 PHP
php反序列化长度变化尾部字符串逃逸(0CTF-2016-piapiapia)
2020/02/15 PHP
CLASS_CONFUSION JS混淆 全源码
2007/12/12 Javascript
ExtJS 2.0实用简明教程 之Ext类库简介
2009/04/29 Javascript
利用jquery操作Radio方法小结
2014/10/20 Javascript
JavaScript编程中布尔对象的基本使用
2015/10/25 Javascript
个人网站留言页面(前端jQuery编写、后台php读写MySQL)
2016/05/03 Javascript
论Bootstrap3和Foundation5网格系统的异同
2016/05/16 Javascript
Bootstrap Paginator分页插件与ajax相结合实现动态无刷新分页效果
2016/05/27 Javascript
原生js实现网易轮播图效果
2020/04/10 Javascript
JS实现根据用户输入分钟进行倒计时功能
2016/11/14 Javascript
使用Angular CLI生成路由的方法
2018/03/24 Javascript
基于React Native 0.52实现轮播图效果
2020/08/25 Javascript
详解使用React制作一个模态框
2019/03/14 Javascript
微信小程序-form表单提交代码实例
2019/04/29 Javascript
vue+elementUI 实现内容区域高度自适应的示例
2020/09/26 Javascript
详解vue中在父组件点击按钮触发子组件的事件
2020/11/13 Javascript
在Python的Django框架的视图中使用Session的方法
2015/07/23 Python
Python基于回溯法子集树模板解决野人与传教士问题示例
2017/09/11 Python
python opencv 图像尺寸变换方法
2018/04/02 Python
用Python将mysql数据导出成json的方法
2018/08/21 Python
python 堆和优先队列的使用详解
2019/03/05 Python
Flask使用Pyecharts在单个页面展示多个图表的方法
2019/08/05 Python
在tensorflow中实现去除不足一个batch的数据
2020/01/20 Python
详解python 支持向量机(SVM)算法
2020/09/18 Python
Pycharm plot独立窗口显示的操作
2020/12/11 Python
Html5 webRTC简单实现视频调用的示例代码
2020/09/23 HTML / CSS
John Hardy官方网站:手工设计首饰的奢侈品牌
2017/07/05 全球购物
俄罗斯马克西多姆家居用品网上商店:Максидом
2020/02/06 全球购物
十佳大学生村官事迹
2014/01/09 职场文书
2014年教师党员自我评价范文
2014/09/22 职场文书
2014年学生会部门工作总结
2014/11/07 职场文书
归元寺导游词
2015/02/06 职场文书
【海涛教你打DOTA】虚空假面第一视角骨弓3房29杀
2022/04/01 DOTA