TensorFlow MNIST手写数据集的实现方法


Posted in Python onFebruary 05, 2020

MNIST数据集介绍

MNIST数据集中包含了各种各样的手写数字图片,数据集的官网是:http://yann.lecun.com/exdb/mnist/index.html,我们可以从这里下载数据集。使用如下的代码对数据集进行加载:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

运行上述代码会自动下载数据集并将文件解压在MNIST_data文件夹下面。代码中的one_hot=True,表示将样本的标签转化为one_hot编码。

MNIST数据集中的图片是28*28的,每张图被转化为一个行向量,长度是28*28=784,每一个值代表一个像素点。数据集中共有60000张手写数据图片,其中55000张训练数据,5000张测试数据。

在MNIST中,mnist.train.images是一个形状为[55000, 784]的张量,其中的第一个维度是用来索引图片,第二个维度图片中的像素。MNIST数据集包含有三部分,训练数据集,验证数据集,测试数据集(mnist.validation)。

标签是介于0-9之间的数字,用于描述图片中的数字,转化为one-hot向量即表示的数字对应的下标为1,其余的值为0。标签的训练数据是[55000,10]的数字矩阵。

下面定义了一个简单的网络对数据集进行训练,代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
tf.reset_default_graph()
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
pred = tf.matmul(x, w) + b
pred = tf.nn.softmax(pred)
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
training_epochs = 25
batch_size = 100
display_step = 1
save_path = 'model/'
saver = tf.train.Saver()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = int(mnist.train.num_examples/batch_size)
    for i in range(total_batch):
      batch_xs, batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer, cost], feed_dict={x:batch_xs, y:batch_ys})
      avg_cost += c / total_batch
    if (epoch + 1) % display_step == 0:
      print('epoch= ', epoch+1, ' cost= ', avg_cost)
  print('finished')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x:mnist.test.images, y:mnist.test.labels}))
  save = saver.save(sess, save_path=save_path+'mnist.cpkt')
print(" starting 2nd session ...... ")
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, save_path=save_path+'mnist.cpkt')
  correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print('accuracy: ', accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
  output = tf.argmax(pred, 1)
  batch_xs, batch_ys = mnist.test.next_batch(2)
  outputval= sess.run([output], feed_dict={x:batch_xs, y:batch_ys})
  print(outputval)
  im = batch_xs[0]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()
  im = batch_xs[1]
  im = im.reshape(-1, 28)
  plt.imshow(im, cmap='gray')
  plt.show()

总结

以上所述是小编给大家介绍的TensorFlow MNIST手写数据集的实现方法,希望对大家有所帮助!

Python 相关文章推荐
Python文件和流(实例讲解)
Sep 12 Python
Python实现基本数据结构中队列的操作方法示例
Dec 04 Python
Python实现购物车购物小程序
Apr 18 Python
Jupyter notebook远程访问服务器的方法
May 24 Python
python实现nao机器人手臂动作控制
Apr 29 Python
python程序快速缩进多行代码方法总结
Jun 23 Python
python实现各种插值法(数值分析)
Jul 30 Python
Python字符串格式化输出代码实例
Nov 22 Python
python3 assert 断言的使用详解 (区别于python2)
Nov 27 Python
pymysql模块的操作实例
Dec 17 Python
如何将PySpark导入Python的放实现(2种)
Apr 26 Python
python爬虫构建代理ip池抓取数据库的示例代码
Sep 22 Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
解决Tensorflow 内存泄露问题
Feb 05 #Python
You might like
URL Rewrite的设置方法
2007/01/02 PHP
使用PHP的日期与时间函数技巧
2008/04/24 PHP
php ss7.5的数据调用 (笔记)
2010/03/08 PHP
PHP实现对文本数据库的常用操作方法实例演示
2014/07/04 PHP
Yii2.0高级框架数据库增删改查的一些操作
2015/11/16 PHP
php+redis实现商城秒杀功能
2020/11/19 PHP
javascript之dhDataGrid Ver2.0.0代码
2007/07/01 Javascript
js利用Array.splice实现Array的insert/remove
2009/01/13 Javascript
jQuery对象数据缓存Cache原理及jQuery.data方法区别介绍
2013/04/07 Javascript
封装好的javascript前端分页插件pagination
2016/01/04 Javascript
js和jquery实现监听键盘事件示例代码
2020/06/24 Javascript
BootStrap学习系列之Bootstrap Typeahead 组件实现百度下拉效果(续)
2016/07/07 Javascript
AngularJS使用ng-app自动加载bootstrap框架问题分析
2017/01/04 Javascript
实例分析JS与Node.js中的事件循环
2017/12/12 Javascript
d3绘制基本的柱形图的实现代码
2018/12/12 Javascript
JavaScript中concat复制数组方法浅析
2019/01/20 Javascript
VUE接入腾讯验证码功能(滑块验证)备忘
2019/05/07 Javascript
JavaScript实现飞舞的泡泡效果
2020/02/07 Javascript
浅谈实现在线预览PDF的几种解决办法
2020/08/10 Javascript
[02:23]DOTA2英雄基础教程 幻影长矛手
2013/12/09 DOTA
[01:32]2016国际邀请赛中国区预选赛CDEC战队教练采访
2016/06/26 DOTA
基于Python中capitalize()与title()的区别详解
2017/12/09 Python
Python基于滑动平均思想实现缺失数据填充的方法
2019/02/21 Python
Python中的self用法详解
2019/08/06 Python
python 求定积分和不定积分示例
2019/11/20 Python
pytorch 实现在预训练模型的 input上增减通道
2020/01/06 Python
python爬虫调度器用法及实例代码
2020/11/30 Python
CSS3美化表单控件全集
2016/06/29 HTML / CSS
师范生自我鉴定范文
2013/10/05 职场文书
西门豹教学反思
2014/02/04 职场文书
少先队学雷锋活动月总结
2014/03/09 职场文书
乒乓球兴趣小组活动总结
2014/07/08 职场文书
五心教育心得体会
2014/09/04 职场文书
四风问题个人对照检查剖析材料
2014/09/27 职场文书
浅析MongoDB之安全认证
2021/06/26 MongoDB
Vue3.0 手写放大镜效果
2021/07/25 Vue.js