在Tensorflow中实现梯度下降法更新参数值


Posted in Python onJanuary 23, 2020

我就废话不多说了,直接上代码吧!

tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

TensorFlow经过使用梯度下降法对损失函数中的变量进行修改值,默认修改tf.Variable(tf.zeros([784,10]))

为Variable的参数。

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[w,b])

也可以使用var_list参数来定义更新那些参数的值

#导入Minst数据集
import input_data
mnist = input_data.read_data_sets("data",one_hot=True)
 
#导入tensorflow库
import tensorflow as tf
 
#输入变量,把28*28的图片变成一维数组(丢失结构信息)
x = tf.placeholder("float",[None,784])
 
#权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出
w = tf.Variable(tf.zeros([784,10]))
#偏置
b = tf.Variable(tf.zeros([10]))
 
#核心运算,其实就是softmax(x*w+b)
y = tf.nn.softmax(tf.matmul(x,w) + b)
 
#这个是训练集的正确结果
y_ = tf.placeholder("float",[None,10])
 
#交叉熵,作为损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
 
#梯度下降算法,最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
 
#初始化,在run之前必须进行的
init = tf.initialize_all_variables()
#创建session以便运算
sess = tf.Session()
sess.run(init)
 
#迭代1000次
for i in range(1000):
 #获取训练数据集的图片输入和正确表示数字
 batch_xs, batch_ys = mnist.train.next_batch(100)
 #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字
 sess.run(train_step,feed_dict = {x:batch_xs, y_: batch_ys})
 
#tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。
#这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。
#1代表正确,0代表错误
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
 
#tf.cast先将数据转换成float,防止求平均不准确。
#tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
#输出
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_: mnist.test.labels}))

计算结果如下

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:49:45.866600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:49:45.866600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.9163
 
Process finished with exit code 0

如果限制,只更新参数W查看效果

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:51:08.543600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:51:08.544600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.9187
 
Process finished with exit code 0

可以看出只修改W对结果影响不大,如果设置只修改b

#导入Minst数据集
import input_data
mnist = input_data.read_data_sets("data",one_hot=True)
 
#导入tensorflow库
import tensorflow as tf
 
#输入变量,把28*28的图片变成一维数组(丢失结构信息)
x = tf.placeholder("float",[None,784])
 
#权重矩阵,把28*28=784的一维输入,变成0-9这10个数字的输出
w = tf.Variable(tf.zeros([784,10]))
#偏置
b = tf.Variable(tf.zeros([10]))
 
#核心运算,其实就是softmax(x*w+b)
y = tf.nn.softmax(tf.matmul(x,w) + b)
 
#这个是训练集的正确结果
y_ = tf.placeholder("float",[None,10])
 
#交叉熵,作为损失函数
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
 
#梯度下降算法,最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy,var_list=[b])
 
#初始化,在run之前必须进行的
init = tf.initialize_all_variables()
#创建session以便运算
sess = tf.Session()
sess.run(init)
 
#迭代1000次
for i in range(1000):
 #获取训练数据集的图片输入和正确表示数字
 batch_xs, batch_ys = mnist.train.next_batch(100)
 #运行刚才建立的梯度下降算法,x赋值为图片输入,y_赋值为正确的表示数字
 sess.run(train_step,feed_dict = {x:batch_xs, y_: batch_ys})
 
#tf.argmax获取最大值的索引。比较运算后的结果和本身结果是否相同。
#这步的结果应该是[1,1,1,1,1,1,1,1,0,1...........1,1,0,1]这种形式。
#1代表正确,0代表错误
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
 
#tf.cast先将数据转换成float,防止求平均不准确。
#tf.reduce_mean由于只有一个参数,就是上面那个数组的平均值。
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
#输出
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y_: mnist.test.labels}))

计算结果:

"C:\Program Files\Anaconda3\python.exe" D:/pycharmprogram/tensorflow_learn/softmax_learn/softmax_learn.py
Extracting data\train-images-idx3-ubyte.gz
Extracting data\train-labels-idx1-ubyte.gz
Extracting data\t10k-images-idx3-ubyte.gz
Extracting data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\util\tf_should_use.py:175: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2018-05-14 15:52:04.483600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2018-05-14 15:52:04.483600: W C:\tf_jenkins\home\workspace\rel-win\M\windows\PY\35\tensorflow\core\platform\cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
0.1135
 
Process finished with exit code 0

如果只更新b那么对效果影响很大。

以上这篇在Tensorflow中实现梯度下降法更新参数值就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python生成不重复随机值的方法
May 11 Python
python中pygame针对游戏窗口的显示方法实例分析(附源码)
Nov 11 Python
Python爬取APP下载链接的实现方法
Sep 30 Python
浅谈Python基础之I/O模型
May 11 Python
Python多进程库multiprocessing中进程池Pool类的使用详解
Nov 24 Python
Python+selenium实现自动循环扔QQ邮箱漂流瓶
May 29 Python
python 提取key 为中文的json 串方法
Dec 31 Python
Python3 单行多行万能正则匹配方法
Jan 07 Python
Pyinstaller打包.py生成.exe的方法和报错总结
Apr 02 Python
解决Python pip 自动更新升级失败的问题
Feb 21 Python
Django form表单与请求的生命周期步骤详解
Jun 07 Python
python 实现定时任务的四种方式
Apr 01 Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
使用tensorflow显示pb模型的所有网络结点方式
Jan 23 #Python
tensorflow 实现打印pb模型的所有节点
Jan 23 #Python
TensorFlow命名空间和TensorBoard图节点实例
Jan 23 #Python
tensorflow通过模型文件,使用tensorboard查看其模型图Graph方式
Jan 23 #Python
如何定义TensorFlow输入节点
Jan 23 #Python
You might like
DC四月将推出百页特刊漫画 纪念小丑诞生80周年
2020/04/09 欧美动漫
PHP json格式和js json格式 js跨域调用实现代码
2012/09/08 PHP
深入php多态的实现详解
2013/06/09 PHP
php修改上传图片尺寸的方法
2015/04/14 PHP
PHP使用栈解决约瑟夫环问题算法示例
2017/08/27 PHP
php提取微信账单的有效信息
2018/10/01 PHP
php测试kafka项目示例
2020/02/06 PHP
PHP中echo与print区别点整理
2021/03/09 PHP
js 判断浏览器使用的语言示例代码
2014/03/22 Javascript
使用cluster 将自己的Node服务器扩展为多线程服务器
2014/11/10 Javascript
js实现兼容性好的微软官网导航下拉菜单效果
2015/09/07 Javascript
JavaScript匿名函数之模仿块级作用域
2015/12/12 Javascript
微信小程序 页面传参实例详解
2016/11/16 Javascript
Angular使用$http.jsonp发送跨站请求的方法
2017/03/16 Javascript
JavaScript条件判断_动力节点Java学院整理
2017/06/26 Javascript
JS自定义滚动条效果简单实现代码
2020/10/27 Javascript
Vuejs 页面的区域化与组件封装的实现
2017/09/11 Javascript
如何编写一个完整的Angular4 FormText 组件
2017/11/18 Javascript
快速解决vue-cli在ie9+中无效的问题
2018/09/04 Javascript
webpack4实现不同的导出类型
2019/04/09 Javascript
nuxt踩坑之Vuex状态树的模块方式使用详解
2019/09/06 Javascript
Python smallseg分词用法实例分析
2015/05/28 Python
python中使用PIL制作并验证图片验证码
2018/03/15 Python
python scatter散点图用循环分类法加图例
2019/03/19 Python
Django Rest framework权限的详细用法
2019/07/25 Python
ipad上运行python的方法步骤
2019/10/12 Python
如何使用css3实现一个类在线直播的队列动画的示例代码
2020/06/17 HTML / CSS
植物选择:Botanic Choice
2017/02/15 全球购物
伦敦剧院门票:From The Box Office
2018/06/30 全球购物
环境科学专业大学生自荐信格式
2013/09/21 职场文书
六十大寿答谢词
2014/01/12 职场文书
实习生求职自荐信
2014/02/07 职场文书
2014年服务员个人工作总结
2014/12/23 职场文书
资料员岗位职责
2015/02/10 职场文书
2015年治庸问责工作总结
2015/07/27 职场文书
Linux中Nginx的防盗链和优化的实现代码
2021/06/20 Servers