基于Tensorflow的MNIST手写数字识别分类


Posted in Python onJune 17, 2020

本文实例为大家分享了基于Tensorflow的MNIST手写数字识别分类的具体实现代码,供大家参考,具体内容如下

代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.tensorboard.plugins import projector
import time

IMAGE_PIXELS = 28
hidden_unit = 100
output_nums = 10
learning_rate = 0.001
train_steps = 50000
batch_size = 500
test_data_size = 10000
#日志目录(这里根据自己的目录修改)
logdir = 'D:/Develop_Software/Anaconda3/WorkDirectory/summary/mnist'
#导入mnist数据
mnist = input_data.read_data_sets('MNIST_data', one_hot = True)

 #全局训练步数
global_step = tf.Variable(0, name = 'global_step', trainable = False)
with tf.name_scope('input'):
 #输入数据
 with tf.name_scope('x'):
 x = tf.placeholder(
  dtype = tf.float32, shape = (None, IMAGE_PIXELS * IMAGE_PIXELS))
 #收集x图像的会总数据
 with tf.name_scope('x_summary'):
 shaped_image_batch = tf.reshape(
  tensor = x,
  shape = (-1, IMAGE_PIXELS, IMAGE_PIXELS, 1),
  name = 'shaped_image_batch')
 tf.summary.image(name = 'image_summary',
      tensor = shaped_image_batch,
      max_outputs = 10)
 with tf.name_scope('y_'):
 y_ = tf.placeholder(dtype = tf.float32, shape = (None, 10))

with tf.name_scope('hidden_layer'):
 with tf.name_scope('hidden_arg'):
 #隐层模型参数
 with tf.name_scope('hid_w'):
  
  hid_w = tf.Variable(
   tf.truncated_normal(shape = (IMAGE_PIXELS * IMAGE_PIXELS, hidden_unit)),
   name = 'hidden_w')
  #添加获取隐层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = hid_w)
  with tf.name_scope('hid_b'):
  hid_b = tf.Variable(tf.zeros(shape = (1, hidden_unit), dtype = tf.float32),
       name = 'hidden_b')
 #隐层输出
 with tf.name_scope('relu'):
 hid_out = tf.nn.relu(tf.matmul(x, hid_w) + hid_b)
with tf.name_scope('softmax_layer'):
 with tf.name_scope('softmax_arg'):
 #softmax层参数
 with tf.name_scope('sm_w'):
  
  sm_w = tf.Variable(
   tf.truncated_normal(shape = (hidden_unit, output_nums)),
   name = 'softmax_w')
  #添加获取softmax层权重统计值汇总数据的汇总操作
  tf.summary.histogram(name = 'weights', values = sm_w)
  with tf.name_scope('sm_b'):
  sm_b = tf.Variable(tf.zeros(shape = (1, output_nums), dtype = tf.float32), 
       name = 'softmax_b')
 #softmax层的输出
 with tf.name_scope('softmax'):
 y = tf.nn.softmax(tf.matmul(hid_out, sm_w) + sm_b)
 #梯度裁剪,因为概率取值为[0, 1]为避免出现无意义的log(0),故将y值裁剪到[1e-10, 1]
 y_clip = tf.clip_by_value(y, 1.0e-10, 1 - 1.0e-5)
with tf.name_scope('cross_entropy'):
 #使用交叉熵代价函数
 cross_entropy = -tf.reduce_sum(y_ * tf.log(y_clip) + (1 - y_) * tf.log(1 - y_clip))
 #添加获取交叉熵的汇总操作
 tf.summary.scalar(name = 'cross_entropy', tensor = cross_entropy)
 
with tf.name_scope('train'):
 #若不使用同步训练机制,使用Adam优化器
 optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate)
 #单步训练操作,
 train_op = optimizer.minimize(cross_entropy, global_step = global_step)
#加载测试数据
test_image = mnist.test.images
test_label = mnist.test.labels
test_feed = {x:test_image, y_:test_label}

with tf.name_scope('accuracy'):
 prediction = tf.equal(tf.argmax(input = y, axis = 1),
      tf.argmax(input = y_, axis = 1))
 accuracy = tf.reduce_mean(
  input_tensor = tf.cast(x = prediction, dtype = tf.float32))
#创建嵌入变量
embedding_var = tf.Variable(test_image, trainable = False, name = 'embedding')
saver = tf.train.Saver({'embedding':embedding_var})
#创建元数据文件,将MNIST图像测试集对应的标签写入文件
def CreateMedaDataFile():
 with open(logdir + '/metadata.tsv', 'w') as f:
 label = np.nonzero(test_label)[1]
 for i in range(test_data_size):
  f.write('%d\n' % label[i])
#创建投影配置参数
def CreateProjectorConfig():
 config = projector.ProjectorConfig()
 embeddings = config.embeddings.add()
 embeddings.tensor_name = 'embedding:0'
 embeddings.metadata_path = logdir + '/metadata.tsv'
 
 projector.visualize_embeddings(writer, config)
 #聚集汇总操作
merged = tf.summary.merge_all()
#创建会话的配置参数
sess_config = tf.ConfigProto(
 allow_soft_placement = True,
 log_device_placement = False)
#创建会话
with tf.Session(config = sess_config) as sess:
 #创建FileWriter实例
 writer = tf.summary.FileWriter(logdir = logdir, graph = sess.graph)
 #初始化全局变量
 sess.run(tf.global_variables_initializer())
 time_begin = time.time()
 print('Training begin time: %f' % time_begin)
 while True:
 #加载训练批数据
 batch_x, batch_y = mnist.train.next_batch(batch_size)
 train_feed = {x:batch_x, y_:batch_y}
 loss, _, summary= sess.run([cross_entropy, train_op, merged], feed_dict = train_feed)
 step = global_step.eval()
 #如果step为100的整数倍
 if step % 100 == 0:
  now = time.time()
  print('%f: global_step = %d, loss = %f' % (
   now, step, loss))
  #向事件文件中添加汇总数据
  writer.add_summary(summary = summary, global_step = step)
 #若大于等于训练总步数,退出训练
 if step >= train_steps:
  break
 time_end = time.time()
 print('Training end time: %f' % time_end)
 print('Training time: %f' % (time_end - time_begin))
 #测试模型精度
 test_accuracy = sess.run(accuracy, feed_dict = test_feed)
 print('accuracy: %f' % test_accuracy)
 
 saver.save(sess = sess, save_path = logdir + '/embedding_var.ckpt')
 CreateMedaDataFile()
 CreateProjectorConfig()
 #关闭FileWriter
 writer.close()

基于Tensorflow的MNIST手写数字识别分类

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现全角半角转换的方法
Aug 18 Python
Python 装饰器深入理解
Mar 16 Python
python基础教程项目二之画幅好画
Apr 02 Python
python 通过麦克风录音 生成wav文件的方法
Jan 09 Python
Python2.7实现多进程下开发多线程示例
May 31 Python
Python数据类型之列表和元组的方法实例详解
Jul 08 Python
parser.add_argument中的action使用
Apr 20 Python
tensorflow使用freeze_graph.py将ckpt转为pb文件的方法
Apr 22 Python
使用Keras实现Tensor的相乘和相加代码
Jun 18 Python
matplotlib 三维图表绘制方法简介
Sep 20 Python
Python jieba库分词模式实例用法
Jan 13 Python
Python标准库pathlib操作目录和文件
Nov 20 Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
Python多线程threading创建及使用方法解析
Jun 17 #Python
Python偏函数Partial function使用方法实例详解
Jun 17 #Python
详解Python IO口多路复用
Jun 17 #Python
基于keras中的回调函数用法说明
Jun 17 #Python
Python学习之路安装pycharm的教程详解
Jun 17 #Python
Python闭包及装饰器运行原理解析
Jun 17 #Python
You might like
浅析HTTP消息头网页缓存控制以及header常用指令介绍
2013/06/28 PHP
php获取中文拼音首字母类和函数分享
2014/04/24 PHP
PHP内置过滤器FILTER使用实例
2014/06/25 PHP
weiphp微信公众平台授权设置
2016/01/04 PHP
Zend Framework教程之Resource Autoloading用法实例
2016/03/08 PHP
php中strlen和mb_strlen用法实例分析
2016/11/12 PHP
PHP代码加密的方法总结
2020/03/13 PHP
纯js实现瀑布流展现照片(自动适应窗口大小)
2013/04/08 Javascript
jquery实现树形菜单完整代码
2015/12/29 Javascript
jquery中validate与form插件提交的方式小结
2016/03/26 Javascript
js实现四舍五入完全保留两位小数的方法
2016/08/02 Javascript
利用iscroll4实现轮播图效果实例代码
2017/01/11 Javascript
jQuery实现的背景颜色渐变动画效果示例
2017/03/24 jQuery
5分钟打造简易高效的webpack常用配置
2017/07/04 Javascript
通过命令行创建vue项目的方法
2017/07/20 Javascript
Javascript实现秒表倒计时功能
2018/11/17 Javascript
bootstrap table列和表头对不齐的解决方法
2019/07/19 Javascript
jquery弹窗时禁止body滚动条滚动的例子
2019/09/21 jQuery
[48:23]DOTA2上海特级锦标赛主赛事日 - 4 败者组第四轮#1COL VS EG第一局
2016/03/05 DOTA
[43:18]NB vs Infamous 2019国际邀请赛淘汰赛 败者组 BO3 第一场 8.22
2019/09/05 DOTA
Python程序中的观察者模式结构编写示例
2016/05/27 Python
python daemon守护进程实现
2016/08/27 Python
对python numpy数组中冒号的使用方法详解
2018/04/17 Python
Python学习笔记之变量、自定义函数用法示例
2019/05/28 Python
html5 Canvas画图教程(5)—canvas里画曲线之arc方法
2013/01/09 HTML / CSS
英国拖鞋购买网站:Bedroom Athletics
2020/02/28 全球购物
医学专业毕业生推荐信
2014/07/12 职场文书
飞越疯人院观后感
2015/06/09 职场文书
长江七号观后感
2015/06/11 职场文书
初中班主任培训心得体会
2016/01/07 职场文书
Golang之sync.Pool使用详解
2021/05/06 Golang
记一次Mysql不走日期字段索引的原因小结
2021/10/24 MySQL
python游戏开发Pygame框架
2022/04/22 Python
Python如何利用pandas读取csv数据并绘图
2022/07/07 Python
mysql sock文件存储了什么信息
2022/07/15 MySQL
插件导致ECharts被全量引入的坑示例解析
2022/09/23 Javascript