tensorflow实现softma识别MNIST


Posted in Python onMarch 12, 2018

识别MNIST已经成了深度学习的hello world,所以每次例程基本都会用到这个数据集,这个数据集在tensorflow内部用着很好的封装,因此可以方便地使用。

这次我们用tensorflow搭建一个softmax多分类器,和之前搭建线性回归差不多,第一步是通过确定变量建立图模型,然后确定误差函数,最后调用优化器优化。

误差函数与线性回归不同,这里因为是多分类问题,所以使用了交叉熵。

另外,有一点值得注意的是,这里构建模型时我试图想拆分多个函数,但是后来发现这样做难度很大,因为图是在规定变量就已经定义好的,不能随意拆分,也不能当做变量传来传去,因此需要将他们写在一起。

代码如下:

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
 
def softmax(mnist,rate=0.01,batchSize=50,epoch=20): 
 n = 784 # 向量的维度数目 
 m = None # 样本数,这里可以获取,也可以不获取 
 c = 10 # 类别数目 
 
 x = tf.placeholder(tf.float32,[m,n]) 
 y = tf.placeholder(tf.float32,[m,c]) 
 
 w = tf.Variable(tf.zeros([n,c])) 
 b = tf.Variable(tf.zeros([c])) 
 
 pred= tf.nn.softmax(tf.matmul(x,w)+b) 
 loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) 
 opt = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 for index in range(epoch): 
  avgLoss = 0 
  batchNum = int(mnist.train.num_examples/batchSize) 
  for batch in range(batchNum): 
   batch_x,batch_y = mnist.train.next_batch(batchSize) 
   _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y}) 
   avgLoss += Loss 
  avgLoss /= batchNum 
  print 'every epoch average loss is ',avgLoss 
 
 right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) 
 accuracy = tf.reduce_mean(tf.cast(right,tf.float32)) 
 print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels})) 
 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 softmax(mnist)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python构造自定义方法来美化字典结构输出的示例
Jun 16 Python
用python一行代码得到数组中某个元素的个数方法
Jan 28 Python
pyqt5 实现 下拉菜单 + 打开文件的示例代码
Jun 20 Python
对python3 sort sorted 函数的应用详解
Jun 27 Python
python读写配置文件操作示例
Jul 03 Python
python代码实现逻辑回归logistic原理
Aug 07 Python
python3的url编码和解码,自定义gbk、utf-8的例子
Aug 22 Python
Django异步任务线程池实现原理
Dec 17 Python
Python3+Selenium+Chrome实现自动填写WPS表单
Feb 12 Python
在Python 的线程中运行协程的方法
Feb 24 Python
python 基于PYMYSQL使用MYSQL数据库
Dec 24 Python
Python中Pyspider爬虫框架的基本使用详解
Jan 27 Python
wxpython实现图书管理系统
Mar 12 #Python
人生苦短我用python python如何快速入门?
Mar 12 #Python
tensorflow实现KNN识别MNIST
Mar 12 #Python
Python操作MySQL模拟银行转账
Mar 12 #Python
python3 图片referer防盗链的实现方法
Mar 12 #Python
tensorflow构建BP神经网络的方法
Mar 12 #Python
Python管理Windows服务小脚本
Mar 12 #Python
You might like
php代码优化及php相关问题总结
2006/10/09 PHP
PHP生成二维码的两个方法和实例
2014/07/01 PHP
PHP使用静态方法的几个注意事项
2014/09/16 PHP
php中使用sftp教程
2015/03/30 PHP
PHP Trait代码复用类与多继承实现方法详解
2019/06/17 PHP
JavaScript Eval 函数使用
2010/03/23 Javascript
fancybox1.3.1 基于Jquery的插件在IE中图片显示问题
2010/10/01 Javascript
js 判断一个元素是否在页面中存在
2012/12/27 Javascript
2012年开发人员的16款新鲜的jquery插件体验分享
2012/12/28 Javascript
Javascript自定义排序 node运行 实例
2013/06/05 Javascript
javascript实现可改变滚动方向的无缝滚动实例
2013/06/17 Javascript
js(jQuery)获取时间的方法及常用时间类搜集
2013/10/23 Javascript
Knockout数组(observable)使用详解示例
2013/11/15 Javascript
JavaScript的React框架中的JSX语法学习入门教程
2016/03/05 Javascript
jQuery给指定的table动态添加删除行的操作方法
2016/10/12 Javascript
解析预加载显示图片艺术
2016/12/05 Javascript
Vue2单一事件管理组件通信
2017/05/09 Javascript
BootstrapTable加载按钮功能实例代码详解
2017/09/22 Javascript
vue form check 表单验证的实现代码
2018/12/09 Javascript
vue源码中的检测方法的实现
2019/09/26 Javascript
JS数组扁平化、去重、排序操作实例详解
2020/02/24 Javascript
Python随机生成信用卡卡号的实现方法
2015/05/14 Python
python 中split 和 strip的实例详解
2017/07/12 Python
利用python循环创建多个文件的方法
2018/10/25 Python
python图像处理模块Pillow的学习详解
2019/10/09 Python
python opencv图片编码为h264文件的实例
2019/12/12 Python
Python 实现自动登录+点击+滑动验证功能
2020/06/10 Python
浅析pandas随机排列与随机抽样
2021/01/22 Python
详解Html5原生拖拽操作
2018/01/12 HTML / CSS
金融事务专业求职信
2014/04/25 职场文书
小学学校评估方案
2014/06/08 职场文书
岗位竞聘报告范文
2014/11/06 职场文书
2015年高校辅导员工作总结
2015/04/20 职场文书
什么是创业计划书?什么是商业计划书?这里一一解答
2019/07/12 职场文书
Redis 彻底禁用RDB持久化操作
2021/07/09 Redis
十大好看的穿越动漫排名:《瑞克和莫蒂》第一,国漫《有药》在榜
2022/03/18 日漫