Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python类型强制转换long to int的代码
Feb 10 Python
按日期打印Python的Tornado框架中的日志的方法
May 02 Python
python简单实现基于SSL的IRC bot实例
Jun 15 Python
python使用Tkinter实现在线音乐播放器
Jan 30 Python
Python读取mat文件,并保存为pickle格式的方法
Oct 23 Python
django使用LDAP验证的方法示例
Dec 10 Python
Python实现对字典分别按键(key)和值(value)进行排序的方法分析
Dec 19 Python
把django中admin后台界面的英文修改为中文显示的方法
Jul 26 Python
使用批处理脚本自动生成并上传NuGet包(操作方法)
Nov 19 Python
python使用beautifulsoup4爬取酷狗音乐代码实例
Dec 04 Python
tensorflow 变长序列存储实例
Jan 20 Python
详解解Django 多对多表关系的三种创建方式
Aug 23 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
php生成xml简单实例代码
2009/12/16 PHP
php实现多站点共用session实现单点登录的方法详解
2019/09/18 PHP
js实现运行代码需要刷新的解决方法
2007/08/18 Javascript
javascript入门·图片对象(无刷新变换图片)\滚动图像
2007/10/01 Javascript
JQuery 文本框回车跳到下一个文本框示例代码
2013/08/30 Javascript
js实现页面跳转的五种方法推荐
2016/03/10 Javascript
JavaScript省市区三级联动菜单效果
2016/09/21 Javascript
AngularJS框架中的双向数据绑定机制详解【减少需要重复的开发代码量】
2017/01/19 Javascript
微信小程序之数据双向绑定与数据操作
2017/05/12 Javascript
简述vue状态管理模式之vuex
2018/08/29 Javascript
跨域解决之JSONP和CORS的详细介绍
2018/11/21 Javascript
JavaScript常见继承模式实例小结
2019/01/11 Javascript
微信小程序实现获取小程序码和二维码java接口开发
2019/03/29 Javascript
js实现点击图片在屏幕中间弹出放大效果
2019/09/11 Javascript
JavaScript数组排序小程序实现解析
2020/01/13 Javascript
JS字符串和数组如何实现相互转化
2020/07/02 Javascript
前端vue+elementUI如何实现记住密码功能
2020/09/20 Javascript
Js利用正则表达式去除字符串的中括号
2020/11/23 Javascript
Python用list或dict字段模式读取文件的方法
2017/01/10 Python
手把手教你用python抢票回家过年(代码简单)
2018/01/21 Python
python实现求特征选择的信息增益
2018/12/18 Python
Python常用特殊方法实例总结
2019/03/22 Python
Python datetime包函数简单介绍
2019/08/28 Python
详解python内置常用高阶函数(列出了5个常用的)
2020/02/21 Python
HTML5在IE10、火狐下中文乱码问题的解决方法
2013/11/18 HTML / CSS
车间统计员岗位职责
2014/01/05 职场文书
上课看小说检讨书
2014/02/22 职场文书
学习十八大坚定理想信念心得体会
2014/03/11 职场文书
医药销售自荐书
2014/05/29 职场文书
考博导师推荐信范文
2015/03/27 职场文书
2015年数学教研组工作总结
2015/05/23 职场文书
2015初中生物教研组工作总结
2015/07/21 职场文书
利用Nginx代理如何解决前端跨域问题详析
2021/04/02 Servers
防止web项目中的SQL注入
2021/12/06 MySQL
bootstrapv4轮播图去除两侧阴影及线框的方法
2022/02/15 HTML / CSS
Python获取指定日期是"星期几"的6种方法
2022/03/13 Python