Tensorflow中k.gradients()和tf.stop_gradient()用法说明


Posted in Python onJune 10, 2020

上周在实验室开荒某个代码,看到中间这么一段,对Tensorflow中的stop_gradient()还不熟悉,特此周末进行重新并总结。

y = xx + K.stop_gradient(rounded - xx)

这代码最终调用位置在tensoflow.python.ops.gen_array_ops.stop_gradient(input, name=None),关于这段代码为什么这样写的意义在文末给出。

【stop_gradient()意义】

用stop_gradient生成损失函数w.r.t.的梯度。

【tf.gradients()理解】

tf中我们只需要设计我们自己的函数,tf提供提供强大的自动计算函数梯度方法,tf.gradients()。

tf.gradients(
 ys,
 xs,
 grad_ys=None,
 name='gradients',
 colocate_gradients_with_ops=False,
 gate_gradients=False,
 aggregation_method=None,
 stop_gradients=None,
 unconnected_gradients=tf.UnconnectedGradients.NONE
)

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys.

1、tf.gradients()实现ys对xs的求导

2、ys和xs可以是Tensor或者list包含的Tensor

3、求导返回值是一个list,list的长度等于len(xs)

eg.假设返回值是[grad1, grad2, grad3],ys=[y1, y2],xs=[x1, x2, x3]。则计算过程为:

Tensorflow中k.gradients()和tf.stop_gradient()用法说明

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
gradients_node = tf.gradients(loss_op, w)
 
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([train_op, gradients_node, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

自定义梯度和更新函数

import numpy as np
import tensorflow as tf
 
#构造数据集
x_pure = np.random.randint(-10, 100, 32)
x_train = x_pure + np.random.randn(32) / 32
y_train = 3 * x_pure + 2 + np.random.randn(32) / 32
 
x_input = tf.placeholder(tf.float32, name='x_input')
y_input = tf.placeholder(tf.float32, name='y_input')
w = tf.Variable(2.0, name='weight')
b = tf.Variable(1.0, name='biases')
y = tf.add(tf.multiply(x_input, w), b)
 
loss_op = tf.reduce_sum(tf.pow(y_input - y, 2)) / (2 * 32)
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss_op)
 
#自定义权重更新
grad_w, grad_b = tf.gradients(loss_op, [w, b])
new_w = w.assign(w - 0.01 * grad_w)
new_b = b.assign(b - 0.01 * grad_b)
 
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
 
for i in range(20):
 _, gradients, loss = sess.run([new_w, new_b, loss_op], feed_dict={x_input: x_train[i], y_input: y_train[i]})
 print("epoch: {} \t loss: {} \t gradients: {}".format(i, loss, gradients))
sess.close()

【tf.stop_gradient()理解】

在tf.gradients()参数中存在stop_gradients,这是一个List,list中的元素是tensorflow graph中的op,一旦进入这个list,将不会被计算梯度,更重要的是,在该op之后的BP计算都不会运行。

import numpy as np
import tensorflow as tf
 
a = tf.constant(0.)
b = 2 * a
c = a + b
g = tf.gradients(c, [a, b])
 
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(g))
 
#输出[3.0, 1.0]

在用一个stop_gradient()的例子

import tensorflow as tf
 
#实验一
w1 = tf.Variable(2.0)
w2 = tf.Variable(2.0)
a = tf.multiply(w1, 3.0)
a_stoped = tf.stop_gradient(a)
 
# b=w1*3.0*w2
b = tf.multiply(a_stoped, w2)
gradients = tf.gradients(b, xs=[w1, w2])
print(gradients)
#输出[None, <tf.Tensor 'gradients/Mul_1_grad/Reshape_1:0' shape=() dtype=float32>]
 
#实验二
a = tf.Variable(1.0)
b = tf.Variable(1.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.add(c_stoped, d)
gradients = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
 tf.global_variables_initializer().run()
 print(sess.run(gradients))
 
#因为梯度从另外地方传回,所以输出 [1.0, 1.0]

【答案】

开始提出的问题,为什么存在那段代码:

t = g(x)

y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是XX,但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

以上这篇Tensorflow中k.gradients()和tf.stop_gradient()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python下载Bing图片(代码)
Nov 07 Python
Python遍历指定文件及文件夹的方法
May 09 Python
python daemon守护进程实现
Aug 27 Python
Python中的字符串操作和编码Unicode详解
Jan 18 Python
Python中shutil模块的学习笔记教程
Apr 04 Python
Python实现发送QQ邮件的封装
Jul 14 Python
Django models.py应用实现过程详解
Jul 29 Python
基于tensorflow指定GPU运行及GPU资源分配的几种方式小结
Feb 03 Python
浅谈SciPy中的optimize.minimize实现受限优化问题
Feb 29 Python
python pandas dataframe 去重函数的具体使用
Jul 20 Python
python mongo 向数据中的数组类型新增数据操作
Dec 05 Python
如何用Python搭建gRPC服务
Jun 30 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 #Python
浅谈Python中的字符串
Jun 10 #Python
Keras 使用 Lambda层详解
Jun 10 #Python
keras打印loss对权重的导数方式
Jun 10 #Python
Python xlrd模块导入过程及常用操作
Jun 10 #Python
keras-siamese用自己的数据集实现详解
Jun 10 #Python
python实现mean-shift聚类算法
Jun 10 #Python
You might like
php安全配置 如何配置使其更安全
2011/12/16 PHP
php上传文件中文文件名乱码的解决方法
2013/11/01 PHP
PHP基于curl后台远程登录正方教务系统的方法
2016/10/14 PHP
php微信公众平台开发(四)回复功能开发
2016/12/06 PHP
PHP空值检测函数与方法汇总
2017/11/19 PHP
PHP7数组的底层实现示例
2019/08/25 PHP
javascript脚本编程解决考试分数统计问题
2008/10/18 Javascript
详细分析JavaScript函数定义
2015/07/16 Javascript
详解js中class的多种函数封装方法
2016/01/03 Javascript
js仿百度登录页实现拖动窗口效果
2016/03/11 Javascript
jQuery文本框得到与失去焦点动态改变样式效果
2016/09/08 Javascript
jQuery居中元素scrollleft计算方法示例
2017/01/16 Javascript
discuz表情的JS提取方法分析
2017/03/22 Javascript
使用ES6语法重构React代码详解
2017/05/09 Javascript
JavaScript使用readAsDataURL读取图像文件
2017/05/10 Javascript
利用nginx + node在阿里云部署https的步骤详解
2017/12/19 Javascript
vue请求本地自己编写的json文件的方法
2019/04/25 Javascript
解决layui表格的表头不滚动的问题
2019/09/04 Javascript
Python实现LRU算法的2种方法
2015/06/24 Python
Python中atexit模块的基本使用示例
2015/07/08 Python
win10环境下python3.5安装步骤图文教程
2017/02/03 Python
Python3 mmap内存映射文件示例解析
2020/03/23 Python
使用pytorch实现论文中的unet网络
2020/06/24 Python
使用ITK-SNAP进行抠图操作并保存mask的实例
2020/07/01 Python
CSS3中各种颜色属性的使用教程
2016/05/17 HTML / CSS
Stefania Mode美国:奢华设计师和时尚服装
2018/01/07 全球购物
英国100%防污和防水的靴子:Muck Boot Company
2020/09/08 全球购物
个人找工作的自我评价
2013/10/17 职场文书
校友会欢迎辞
2014/01/13 职场文书
给面试官的感谢信
2014/02/01 职场文书
地球一小时宣传标语
2014/06/24 职场文书
意外死亡赔偿协议书
2014/10/14 职场文书
2014普法依法治理工作总结
2014/12/18 职场文书
检讨书模板
2015/01/29 职场文书
求职导师推荐信范文
2015/03/27 职场文书
Python爬虫入门案例之回车桌面壁纸网美女图片采集
2021/10/16 Python