Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python 从远程服务器下载东西的代码
Feb 10 Python
简单介绍Python中的JSON使用
Apr 28 Python
在Python中处理字符串之isdigit()方法的使用
May 18 Python
Python多线程下载文件的方法
Jul 10 Python
Python构造自定义方法来美化字典结构输出的示例
Jun 16 Python
浅谈DataFrame和SparkSql取值误区
Jun 09 Python
使用Python来开发微信功能
Jun 13 Python
浅谈python 读excel数值为浮点型的问题
Dec 25 Python
jupyter notebook清除输出方式
Apr 10 Python
将pycharm配置为matlab或者spyder的用法说明
Jun 08 Python
深入理解Python 多线程
Jun 16 Python
Numpy 多维数据数组的实现
Jun 18 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
PHP定时自动生成静态HTML的实现代码
2010/06/20 PHP
php 截取字符串并以零补齐str_pad() 函数
2011/05/07 PHP
php在linux下检测mysql同步状态的方法
2015/01/15 PHP
基于PHP实现商品成交时发送短信功能
2016/05/11 PHP
Thinkphp实现站点静态化的方法详解
2017/03/21 PHP
php设计模式之组合模式实例详解【星际争霸游戏案例】
2020/03/27 PHP
Span元素的width属性无效果原因及解决方案
2010/01/15 Javascript
ASP.NET jQuery 实例12 通过使用jQuery validation插件简单实现用户注册页面验证功能
2012/02/03 Javascript
Javascript模块化编程(一)AMD规范(规范使用模块)
2013/01/17 Javascript
javaScript array(数组)使用字符串作为数组下标的方法
2013/11/19 Javascript
js获取客户端外网ip的简单实例
2013/11/21 Javascript
jQuery对Select的操作大集合(收藏)
2013/12/28 Javascript
CSS+JS实现点击文字弹出定时自动关闭DIV层菜单的方法
2015/05/12 Javascript
js实现文本框宽度自适应文本宽度的方法
2015/08/13 Javascript
javascript常用函数(1)
2015/11/04 Javascript
JavaScript数组和对象的复制
2017/03/21 Javascript
微信小程序-横向滑动scroll-view隐藏滚动条
2017/04/20 Javascript
JavaScript面向对象精要(下部)
2017/09/12 Javascript
Vue实现数字输入框中分割手机号码的示例
2017/10/10 Javascript
vuex 的简单使用
2018/03/22 Javascript
Vue 父子组件数据传递的四种方式( inheritAttrs + $attrs + $listeners)
2018/05/04 Javascript
Chart.js 轻量级HTML5图表绘制工具库(知识整理)
2018/05/22 Javascript
layer扩展打开/关闭动画的方法
2019/09/23 Javascript
es6函数之尾递归用法实例分析
2020/04/25 Javascript
Python基于回溯法子集树模板解决旅行商问题(TSP)实例
2017/09/05 Python
Python Selenium Cookie 绕过验证码实现登录示例代码
2018/04/10 Python
python复制列表时[:]和[::]之间有什么区别
2018/10/16 Python
python 格式化输出百分号的方法
2019/01/20 Python
Python定时任务APScheduler原理及实例解析
2020/05/30 Python
阿迪达斯印尼官方网站:adidas印尼
2020/02/10 全球购物
不开辟用于交换数据的临时空间,如何完成字符串的逆序
2012/12/02 面试题
汽车驾驶求职信
2013/10/25 职场文书
毕业生怎样写好自荐信
2013/11/11 职场文书
领导检查欢迎词
2014/01/14 职场文书
写求职信有什么意义
2014/02/17 职场文书
教师个人自我剖析材料
2014/09/29 职场文书