TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
在Django中限制已登录用户的访问的方法
Jul 23 Python
分享一下Python 开发者节省时间的10个方法
Oct 02 Python
Python编程中的文件读写及相关的文件对象方法讲解
Jan 19 Python
Python中字符串的常见操作技巧总结
Jul 28 Python
Python字符串处理实例详解
May 18 Python
使用python实现个性化词云的方法
Jun 16 Python
django 发送手机验证码的示例代码
Apr 25 Python
python爬取Ajax动态加载网页过程解析
Sep 05 Python
Python实现病毒仿真器的方法示例(附demo)
Feb 19 Python
python 实用工具状态机transitions
Nov 21 Python
Python调用系统命令os.system()和os.popen()的实现
Dec 31 Python
Python+pyaudio实现音频控制示例详解
Jul 23 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
如何在PHP程序中防止盗链
2008/04/09 PHP
php实现简单文件下载的方法
2015/01/30 PHP
thinkphp配置文件路径的实现方法
2016/08/30 PHP
php与python实现的线程池多线程爬虫功能示例
2016/10/12 PHP
PHP接口继承及接口多继承原理与实现方法详解
2017/10/18 PHP
Laravel实现短信注册的示例代码
2018/05/29 PHP
javascript 字符串连接的性能问题(多浏览器)
2008/11/18 Javascript
jQuery中bind,live,delegate与one方法的用法及区别解析
2013/12/30 Javascript
javascript函数重载解决方案分享
2014/02/19 Javascript
三种检测iPhone/iPad设备方向的方法
2014/04/23 Javascript
让javascript加载速度倍增的方法(解决JS加载速度慢的问题)
2014/12/12 Javascript
jQuery动态添加
2016/04/07 Javascript
jQuery实现的placeholder效果完整实例
2016/08/02 Javascript
关于javascript原型的修改与重写(覆盖)差别详解
2016/08/31 Javascript
vue之nextTick全面解析
2017/05/17 Javascript
微信小程序实现下拉刷新和轮播图效果
2017/11/21 Javascript
浅析Vue 和微信小程序的区别、比较
2018/08/03 Javascript
nodejs基础之buffer缓冲区用法分析
2018/12/26 NodeJs
vue-cli脚手架引入弹出层layer插件的几种方法
2019/06/24 Javascript
JavaScript Dom 绑定事件操作实例详解
2019/10/02 Javascript
JS实现transform实现扇子效果
2020/01/17 Javascript
[44:15]国士无双DOTA2 6.82版本详解(上)
2014/09/28 DOTA
使用Protocol Buffers的C语言拓展提速Python程序的示例
2015/04/16 Python
Python实现的多线程http压力测试代码
2017/02/08 Python
Python设计模式之工厂方法模式实例详解
2019/01/18 Python
详解Python中pandas的安装操作说明(傻瓜版)
2019/04/08 Python
Python3内置模块之base64编解码方法详解
2019/07/13 Python
python实现全排列代码(回溯、深度优先搜索)
2020/02/26 Python
基于Python编写一个计算器程序,实现简单的加减乘除和取余二元运算
2020/08/05 Python
Python程序慢的重要原因
2020/09/04 Python
浅谈Html5页面打开app的一些思考
2020/03/30 HTML / CSS
电子商务应届生求职信
2013/11/16 职场文书
静心口服夜广告词
2014/03/20 职场文书
摄影展策划方案
2014/06/02 职场文书
县政府领导班子四风问题对照检查材料思想汇报
2014/09/26 职场文书
家长意见书
2015/06/04 职场文书