TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
python进阶教程之模块(module)介绍
Aug 30 Python
Python-嵌套列表list的全面解析
Jun 08 Python
Python的消息队列包SnakeMQ使用初探
Jun 29 Python
python僵尸进程产生的原因
Jul 21 Python
Python列表list解析操作示例【整数操作、字符操作、矩阵操作】
Jul 25 Python
在Python dataframe中出生日期转化为年龄的实现方法
Oct 20 Python
如何在Cloud Studio上执行Python代码?
Aug 09 Python
Python 中如何实现参数化测试的方法示例
Dec 10 Python
python实现字符串和数字拼接
Mar 02 Python
Python3使用tesserocr识别字母数字验证码的实现
Jan 29 Python
浅谈Python列表嵌套字典转化的问题
Apr 07 Python
使用Python通过企业微信应用给企业成员发消息
Apr 18 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
Linux下进行MYSQL编程时插入中文乱码的解决方案
2007/03/15 PHP
CI框架源码阅读,系统常量文件constants.php的配置
2013/02/28 PHP
深入php多态的实现详解
2013/06/09 PHP
PHP遍历数组的方法汇总
2015/04/30 PHP
实例介绍PHP删除数组中的重复元素
2019/03/03 PHP
超棒的javascript页面顶部卷动广告效果
2007/12/01 Javascript
jquery 新手学习常见问题解决方法
2010/04/18 Javascript
基于jsTree的无限级树JSON数据的转换代码
2010/07/27 Javascript
jQuery使用数组编写图片无缝向左滚动
2012/12/11 Javascript
js动态修改input输入框的type属性(实现方法解析)
2013/11/13 Javascript
js为什么不能正确处理小数运算?
2015/12/29 Javascript
Javascript中神奇的this
2016/01/20 Javascript
node.js 中国天气预报 简单实现
2016/06/06 Javascript
Angularjs 事件指令详细整理
2017/07/27 Javascript
Vue Cli与BootStrap结合实现表格分页功能
2017/08/18 Javascript
解决Vue打包之后文件路径出错的问题
2018/03/06 Javascript
element-ui点击查看大图的方法示例
2020/12/14 Javascript
在Angular项目使用socket.io实现通信的方法
2021/01/05 Javascript
python爬虫入门教程之点点美女图片爬虫代码分享
2014/09/02 Python
Python的Flask框架中SQLAlchemy使用时的乱码问题解决
2015/11/07 Python
python中pygame针对游戏窗口的显示方法实例分析(附源码)
2015/11/11 Python
Python学习笔记之if语句的使用示例
2017/10/23 Python
利用pyuic5将ui文件转换为py文件的方法
2019/06/19 Python
Python 循环终止语句的三种方法小结
2019/06/24 Python
Python IDE Pycharm中的快捷键列表用法
2019/08/08 Python
Python 操作 PostgreSQL 数据库示例【连接、增删改查等】
2020/04/21 Python
tensorflow图像裁剪进行数据增强操作
2020/06/30 Python
你所知道的集合类都有哪些?主要方法?
2012/12/31 面试题
幼儿老师求职信
2014/06/30 职场文书
客户经理岗位职责大全
2015/04/09 职场文书
2015年街道除四害工作总结
2015/05/15 职场文书
光荣之路观后感
2015/06/12 职场文书
经典爱情感言
2015/08/03 职场文书
《鸡兔同笼》教学反思
2016/02/19 职场文书
python字典进行运算原理及实例分享
2021/08/02 Python
Ruby处理CSV数据方法详解
2022/04/18 Ruby