TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
python基于xmlrpc实现二进制文件传输的方法
Jun 02 Python
Python验证码识别的方法
Jul 10 Python
浅谈python可视化包Bokeh
Feb 07 Python
Pandas 合并多个Dataframe(merge,concat)的方法
Jun 08 Python
Python操作mongodb数据库的方法详解
Dec 08 Python
Python List cmp()知识点总结
Feb 18 Python
Python3获取拉勾网招聘信息的方法实例
Apr 03 Python
python写入数据到csv或xlsx文件的3种方法
Aug 23 Python
利用Python脚本实现自动刷网课
Feb 03 Python
python实现126邮箱发送邮件
May 20 Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 Python
Python sqlalchemy时间戳及密码管理实现代码详解
Aug 01 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
如何隐藏你的.php文件
2007/01/04 PHP
php开发文档 会员收费1期
2012/08/14 PHP
ThinkPHP的URL重写问题
2014/06/22 PHP
php相对当前文件include其它文件的方法
2015/03/13 PHP
javascript笔记 String类replace函数的一些事
2011/09/22 Javascript
Extjs中使用extend(js继承) 的代码
2012/03/15 Javascript
JavaScript关闭当前页面(窗口)不带任何提示
2014/03/26 Javascript
Jquery和BigFileUpload实现大文件上传及进度条显示
2016/06/27 Javascript
后端接收不到AngularJs中$http.post发送的数据原因分析及解决办法
2016/07/05 Javascript
深入理解ES7的async/await的用法
2017/09/09 Javascript
JS实现仿微信支付弹窗功能
2018/06/25 Javascript
vue 实现axios拦截、页面跳转和token 验证
2018/07/17 Javascript
Vue 权限控制的两种方法(路由验证)
2019/08/16 Javascript
Vue实现鼠标经过文字显示悬浮框效果的示例代码
2020/10/14 Javascript
使用scrapy实现爬网站例子和实现网络爬虫(蜘蛛)的步骤
2014/01/23 Python
详解在Python程序中使用Cookie的教程
2015/04/30 Python
python交互式图形编程实例(一)
2017/11/17 Python
python中requests和https使用简单示例
2018/01/18 Python
Python实现上下班抢个顺风单脚本
2018/02/07 Python
python实现输入数字的连续加减方法
2018/06/22 Python
Python 分享10个PyCharm技巧
2019/07/13 Python
通过Turtle库在Python中绘制一个鼠年福鼠
2020/02/03 Python
Python代码一键转Jar包及Java调用Python新姿势
2020/03/10 Python
Python基于Socket实现简易多人聊天室的示例代码
2020/11/29 Python
房地产还款计划书
2014/01/10 职场文书
入党自我评价范文
2014/02/02 职场文书
社区食品安全实施方案
2014/03/28 职场文书
学校师德承诺书
2014/05/23 职场文书
医药销售自我评价200字
2014/09/11 职场文书
中职毕业生自我鉴定
2014/09/13 职场文书
2014年除四害工作总结
2014/12/06 职场文书
承诺函格式模板
2015/01/21 职场文书
证券区域经理岗位职责
2015/04/10 职场文书
培训通知书模板
2015/04/17 职场文书
读《解忧杂货店》有感:请相信一切都是最好的安排
2019/11/07 职场文书
JavaScript 原型与原型链详情
2021/11/02 Javascript