TensorFlow实现Softmax回归模型


Posted in Python onMarch 09, 2018

一、概述及完整代码

对MNIST(MixedNational Institute of Standard and Technology database)这个非常简单的机器视觉数据集,Tensorflow为我们进行了方便的封装,可以直接加载MNIST数据成我们期望的格式.本程序使用Softmax Regression训练手写数字识别的分类模型.

先看完整代码:

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
 
mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 
print(mnist.train.images.shape, mnist.train.labels.shape) 
print(mnist.test.images.shape, mnist.test.labels.shape) 
print(mnist.validation.images.shape, mnist.validation.labels.shape) 
 
#构建计算图 
x = tf.placeholder(tf.float32, [None, 784]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
y = tf.nn.softmax(tf.matmul(x, W) + b) 
y_ = tf.placeholder(tf.float32, [None, 10]) 
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 
 
#在会话sess中启动图 
sess = tf.InteractiveSession() #创建InteractiveSession对象 
tf.global_variables_initializer().run() #全局参数初始化器 
for i in range(1000): 
 batch_xs, batch_ys = mnist.train.next_batch(100) 
 train_step.run({x: batch_xs, y_: batch_ys}) 
 
#测试验证阶段 
#沿着第1条轴方向取y和y_的最大值的索引并判断是否相等 
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 
#转换bool型tensor为float32型tensor并求平均即得到正确率 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

二、详细解读

首先看一下使用TensorFlow进行算法设计训练的核心步骤

1.定义算法公式,也就是神经网络forward时的计算;

2.定义loss,选定优化器,并制定优化器优化loss;

3.在训练集上迭代训练算法模型;

4.在测试集或验证集上对训练得到的模型进行准确率评测.

首先创建一个Placeholder,即输入张量数据的地方,第一个参数是数据类型dtype,第二个参数是tensor的形状shape.接下来创建SoftmaxRegression模型中的weights(W)和biases(b)的Variable对象,不同于存储数据的tensor一旦使用掉就会消失,Variable在模型训练迭代中是持久存在的,并且在每轮迭代中被更新Variable初始化可以是常量或随机值.接下来实现模型算法y = softmax(Wx + b),TensorFlow语言只需要一行代码,tf.nn包含了大量神经网络的组件,头tf.matmul是矩阵乘法函数.TensorFlow将模型中的forward和backward的内容都自动实现,只要定义好loss,训练的时候会自动求导并进行梯度下降,完成对模型参数的自动学习.定义损失函数lossfunction来描述分类精度,对于多分类问题通常使用cross-entropy交叉熵.先定义一个placeholder输入真实的label,tf.reduce_sum和tf.reduce_mean的功能分别是求和和求平均.构造完损失函数cross-entropy后,再定义一个优化算法即可开始训练.我们采用随机梯度下降SGD,定义好后TensorFlow会自动添加许多运算操作来实现反向传播和梯度下降,而给我们提供的是一个封装好的优化器,只需要每轮迭代时feed数据给它就好.设置好学习率.

构造阶段完成后, 才能启动图. 启动图的第一步是创建一个 Session 对象或InteractiveSession对象, 如果无任何创建参数, 会话构造器将启动默认图.创建InteractiveSession对象会这个Session注册为默认的Session,之后的运算也默认跑在这个Session里面,不同Session之间的数据和运算应该是相互独立的.下一步使用TensorFlow的全局参数初始化器tf.global_variables_initializer病直接执行它的run方法(这个全局参数初始化器应该是1.0.0版本中的新特性,在之前0.10.0版本测试不通过).

至此,以上定义的所有公式其实只是Computation Graph,代码执行到这时,计算还没有实际发生,只有等调用run方法并feed数据时计算才真正执行.

随后一步,就可以开始迭代地执行训练操作train_step.这里每次都从训练集中随机抽取100条样本构成一个mini-batch,并feed给placeholder.

完成迭代训练后,就可以对模型的准确率进行验证.比较y和y_在各个测试样本中最大值所在的索引,然后转换为float32型tensor后求平均即可得到正确率.多次测试后得到在测试集上的正确率为92%左右.还是比较理想的结果.

三、其他补充

1.Sesssion类和InteractiveSession类

对于product =tf.matmul(matrix1, matrix2),调用 sess 的 'run()' 方法来执行矩阵乘法 op, 传入 'product' 作为该方法的参数.上面提到, 'product' 代表了矩阵乘法 op 的输出, 传入它是向方法表明, 我们希望取回矩阵乘法 op 的输出.整个执行过程是自动化的, 会话负责传递op 所需的全部输入. op 通常是并发执行的.函数调用 'run(product)' 触发了图中三个 op (两个常量 op 和一个矩阵乘法 op)的执行.返回值 'result' 是一个 numpy的`ndarray`对象.

Session 对象在使用完后需要关闭以释放资源sess.close(). 除了显式调用 close 外, 也可以使用"with" 代码块 来自动完成关闭动作.

with tf.Session() as sess: 
 result = sess.run([product]) 
 print result

为了便于使用诸如 IPython 之类的 Python 交互环境, 可以使用InteractiveSession代替 Session 类, 使用 Tensor.eval()和 Operation.run()方法代替 Session.run(). 这样可以避免使用一个变量来持有会话.

# 进入一个交互式 TensorFlow 会话. 
import tensorflow as tf 
sess = tf.InteractiveSession() 
x = tf.Variable([1.0, 2.0]) 
a = tf.constant([3.0, 3.0]) 
# 使用初始化器 initializer op 的 run() 方法初始化 'x' 
x.initializer.run() 
# 增加一个减法 sub op, 从 'x' 减去 'a'. 运行减法 op, 输出结果 
sub = tf.sub(x, a) 
print sub.eval() 
# ==> [-2. -1.]

2.tf.reduce_sum

首先,tf.reduce_X一系列运算操作(operation)是实现对一个tensor各种减少维度的数学计算.

tf.reduce_sum(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:沿着给定维度reduction_indices的方向降低input_tensor的维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1, 1, 1] 
#   [1, 1, 1]] 
tf.reduce_sum(x) ==> 6 
tf.reduce_sum(x, 0) ==> [2, 2, 2] 
tf.reduce_sum(x, 1) ==> [3, 3] 
tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]] 
tf.reduce_sum(x, [0, 1]) ==> 6

3.tf.reduce_mean

tf.reduce_mean(input_tensor, reduction_indices=None,keep_dims=False, name=None)

运算功能:将input_tensor沿着给定维度reduction_indices减少维度,除非keep_dims=True,tensor的秩在reduction_indices上减1,被降低的维度的长度为1.如果reduction_indices没有传入参数,所有维度都降低,返回只含有1个元素的tensor.运算最终返回降维后的tensor.

演示代码:

# 'x' is [[1., 1. ] 
#   [2., 2.]] 
tf.reduce_mean(x) ==> 1.5 
tf.reduce_mean(x, 0) ==> [1.5, 1.5] 
tf.reduce_mean(x, 1) ==> [1., 2.]

4.tf.argmax

tf.argmax(input, dimension, name=None)

运算功能:返回input在指定维度下的最大值的索引.返回类型为int64.

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之集合的关系
Sep 24 Python
python实现批量改文件名称的方法
May 25 Python
windows上安装Anaconda和python的教程详解
Mar 28 Python
Python之自动获取公网IP的实例讲解
Oct 01 Python
Python OpenCV中的resize()函数的使用
Jun 20 Python
python pandas模块基础学习详解
Jul 03 Python
Python selenium抓取虎牙短视频代码实例
Mar 02 Python
浅谈pymysql查询语句中带有in时传递参数的问题
Jun 05 Python
Keras-多输入多输出实例(多任务)
Jun 22 Python
python打开文件的方式有哪些
Jun 29 Python
Python游戏开发实例之graphics实现AI五子棋
Nov 01 Python
Python实现Hash算法
Mar 18 Python
用python实现百度翻译的示例代码
Mar 09 #Python
TensorFlow深度学习之卷积神经网络CNN
Mar 09 #Python
TensorFlow实现卷积神经网络CNN
Mar 09 #Python
新手常见6种的python报错及解决方法
Mar 09 #Python
Python 函数基础知识汇总
Mar 09 #Python
Python 使用with上下文实现计时功能
Mar 09 #Python
TensorFlow搭建神经网络最佳实践
Mar 09 #Python
You might like
CPU步进是什么意思?i3-9100F B0步进和U0步进区别知识科普
2020/03/17 数码科技
值得分享的php+ajax实时聊天室
2016/07/20 PHP
深入解析Laravel5.5中的包自动发现Package Auto Discovery
2017/09/13 PHP
利用Homestead快速运行一个Laravel项目的方法详解
2017/11/14 PHP
jQuery图片预加载 等比缩放实现代码
2011/10/04 Javascript
关于jquery性能最佳实践的讨论,与求教
2012/03/30 Javascript
js和jquery对dom节点的操作(创建/追加)
2013/04/21 Javascript
javascript中的document.open()方法使用介绍
2013/10/09 Javascript
浅谈JavaScript中指针和地址
2015/07/26 Javascript
Bootstrap实现登录校验表单(带验证码)
2016/06/23 Javascript
jQuery通过改变input的type属性实现密码显示隐藏切换功能
2017/02/08 Javascript
angular中的http拦截器Interceptors的实现
2017/02/21 Javascript
JavaScript登录记住密码操作(超简单代码)
2017/03/22 Javascript
jQuery实现简单的手风琴效果
2020/04/17 jQuery
angular学习之从零搭建一个angular4.0项目
2017/07/10 Javascript
基于ES6作用域和解构赋值详解
2017/11/03 Javascript
jQuery实现动态加载select下拉列表项功能示例
2018/05/31 jQuery
深入浅析Vue中的 computed 和 watch
2018/06/06 Javascript
JavaScript使用递归和循环实现阶乘的实例代码
2018/08/28 Javascript
简单说说如何使用vue-router插件的方法
2019/04/08 Javascript
vue项目,代码提交至码云,iconfont的用法说明
2020/07/30 Javascript
Python查询Mysql时返回字典结构的代码
2012/06/18 Python
Python基于二分查找实现求整数平方根的方法
2016/05/12 Python
Python中shape计算矩阵的方法示例
2017/04/21 Python
PyQt5 对图片进行缩放的实例
2019/06/18 Python
选择Python写网络爬虫的优势和理由
2019/07/07 Python
Python Numpy数组扩展repeat和tile使用实例解析
2019/12/09 Python
Python 一行代码能实现丧心病狂的功能
2020/01/18 Python
python之随机数函数的实现示例
2020/12/30 Python
详解HTML5 录音的踩坑之旅
2017/12/26 HTML / CSS
怎样写演讲稿
2014/01/04 职场文书
嘉宾邀请函
2015/01/31 职场文书
2016年第二十五次全国助残日活动总结
2016/04/01 职场文书
创业计划书之网吧
2019/10/10 职场文书
golang定时器
2022/04/14 Golang
Python 绘制多因子柱状图
2022/05/11 Python