Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】


Posted in Python onDecember 19, 2019

本文实例讲述了Python tensorflow实现mnist手写数字识别。分享给大家供大家参考,具体如下:

非卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
weight = tf.Variable(tf.ones([784,10]))
bias = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(x_data ,weight) + bias)
#Y_model = tf.nn.sigmoid(tf.matmul(x_data ,weight) + bias)
'''
weight1 = tf.Variable(tf.ones([784,256]))
bias1 = tf.Variable(tf.ones([256]))
Y_model1 = tf.nn.softmax(tf.matmul(x_data ,weight1) + bias1)
weight1 = tf.Variable(tf.ones([256,10]))
bias1 = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(Y_model1 ,weight1) + bias1)
'''
y_data = tf.placeholder("float32", [None, 10])
loss = tf.reduce_sum(tf.pow((y_data - Y_model), 2 ))#92%-93%
#loss = tf.reduce_sum(tf.square(y_data - Y_model)) #90%-91%
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init) # reset values to wrong
for i in range(100000):
  batch_xs, batch_ys = mnist_data.train.next_batch(50)
  sess.run(train, feed_dict = {x_data: batch_xs, y_data: batch_ys})
  if i%50==0:
    correct_predict = tf.equal(tf.arg_max(Y_model,1),tf.argmax(y_data,1))
    accurate = tf.reduce_mean(tf.cast(correct_predict,"float"))
    print(sess.run(accurate,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

卷积实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
data_path = 'F:\CNN\data\mnist'
mnist_data = input_data.read_data_sets(data_path,one_hot=True) #offline dataset
x_data = tf.placeholder("float32", [None, 784]) # None means we can import any number of images
x_image = tf.reshape(x_data, [-1,28,28,1])
w_conv = tf.Variable(tf.ones([5,5,1,32])) #weight
b_conv = tf.Variable(tf.ones([32]))    #bias
h_conv = tf.nn.relu(tf.nn.conv2d(x_image , w_conv,strides=[1,1,1,1],padding='SAME')+ b_conv)
h_pool = tf.nn.max_pool(h_conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
w_fc = tf.Variable(tf.ones([14*14*32,1024]))
b_fc = tf.Variable(tf.ones([1024]))
h_pool_flat = tf.reshape(h_pool,[-1,14*14*32])
h_fc = tf.nn.relu(tf.matmul(h_pool_flat,w_fc) +b_fc)
W_fc = w_fc = tf.Variable(tf.ones([1024,10]))
B_fc = tf.Variable(tf.ones([10]))
Y_model = tf.nn.softmax(tf.matmul(h_fc,W_fc) +B_fc)
y_data = tf.placeholder("float32",[None,10])
loss = -tf.reduce_sum(y_data * tf.log(Y_model))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs,batch_ys =mnist_data.train.next_batch(5)
  sess.run(train_step,feed_dict={x_data:batch_xs,y_data:batch_ys})
  if i%50==0:
    correct_prediction = tf.equal(tf.argmax(Y_model,1),tf.argmax(y_data,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    print(sess.run(accuracy,feed_dict={x_data:mnist_data.test.images,y_data:mnist_data.test.labels}))

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
剖析Django中模版标签的解析与参数传递
Jul 21 Python
Python logging管理不同级别log打印和存储实例
Jan 19 Python
python调用系统ffmpeg实现视频截图、http发送
Mar 06 Python
对python中矩阵相加函数sum()的使用详解
Jan 28 Python
Python第三方库h5py_读取mat文件并显示值的方法
Feb 08 Python
Django 表单模型选择框如何使用分组
May 16 Python
Python 实现数据结构中的的栈队列
May 16 Python
python每5分钟从kafka中提取数据的例子
Dec 23 Python
在Pytorch中计算卷积方法的区别详解(conv2d的区别)
Jan 03 Python
python扫描线填充算法详解
Feb 19 Python
python 实现Harris角点检测算法
Dec 11 Python
python blinker 信号库
May 04 Python
Python: 传递列表副本方式
Dec 19 #Python
python内置模块collections知识点总结
Dec 19 #Python
Python操作redis和mongoDB的方法
Dec 19 #Python
Python 实现Serial 与STM32J进行串口通讯
Dec 18 #Python
实现Python与STM32通信方式
Dec 18 #Python
利用pandas将非数值数据转换成数值的方式
Dec 18 #Python
python 浅谈serial与stm32通信的编码问题
Dec 18 #Python
You might like
一个php导出oracle库的php代码
2009/04/20 PHP
PHP 单引号与双引号的区别
2009/11/24 PHP
ThinkPHP自动验证失败的解决方法
2011/06/09 PHP
PHP 无限分类三种方式 非函数的递归调用!
2011/08/26 PHP
php中inlcude()性能对比详解
2012/09/16 PHP
带密匙的php加密解密示例分享
2014/01/29 PHP
php实现当前页面点击下载文件的实例代码
2016/11/16 PHP
php插入含有特殊符号数据的处理方法
2016/11/24 PHP
PHP实现的解汉诺塔问题算法示例
2018/08/06 PHP
jquery关于事件冒泡和事件委托的技巧及阻止与允许事件冒泡的三种实现方法
2015/11/27 Javascript
基于javascript实现彩票随机数生成(简单版)
2020/04/17 Javascript
AngularJS 2.0入门权威指南
2016/10/08 Javascript
vue组件实例解析
2017/01/10 Javascript
Angular 2 利用Router事件和Title实现动态页面标题的方法
2017/08/23 Javascript
详解如何使用koa实现socket.io官网的例子
2018/11/04 Javascript
layui使用表格渲染获取行数据的例子
2019/09/13 Javascript
Vue实现商品飞入购物车效果(电商项目)
2019/11/26 Javascript
解决vue做详情页跳转的时候使用created方法 数据不会更新问题
2020/07/24 Javascript
[02:02:38]VG vs Mineski Supermajor 败者组 BO3 第一场 6.6
2018/06/07 DOTA
Python中的异常处理简明介绍
2015/04/13 Python
windows系统下Python环境的搭建(Aptana Studio)
2017/03/06 Python
详解Python中for循环是如何工作的
2017/06/30 Python
Python图像处理之识别图像中的文字(实例讲解)
2018/05/10 Python
python去重,一个由dict组成的list的去重示例
2019/01/21 Python
Django时区详解
2019/07/24 Python
Python箱型图绘制与特征值获取过程解析
2019/10/22 Python
python GUI库图形界面开发之PyQt5控件数据拖曳Drag与Drop详细使用方法与实例
2020/02/27 Python
python GUI库图形界面开发之PyQt5表格控件QTableView详细使用方法与实例
2020/03/01 Python
Footshop乌克兰:运动鞋的最大选择
2019/12/01 全球购物
分别介绍一下Session Bean和Entity Bean
2015/03/13 面试题
应届毕业生简历自我评价
2014/01/31 职场文书
在职证明书模板
2015/06/15 职场文书
招商银行工作证明
2015/06/17 职场文书
2019年描写人生经典诗句大全
2019/07/08 职场文书
一条 SQL 语句执行过程
2022/03/17 MySQL
配置Kubernetes外网访问集群
2022/03/31 Servers