基于TensorFlow中自定义梯度的2种方式


Posted in Python onFebruary 04, 2020

前言

在深度学习中,有时候我们需要对某些节点的梯度进行一些定制,特别是该节点操作不可导(比如阶梯除法如 基于TensorFlow中自定义梯度的2种方式 ),如果实在需要对这个节点进行操作,而且希望其可以反向传播,那么就需要对其进行自定义反向传播时的梯度。在有些场景,如[2]中介绍到的梯度反转(gradient inverse)中,就必须在某层节点对反向传播的梯度进行反转,也就是需要更改正常的梯度传播过程,如下图的 基于TensorFlow中自定义梯度的2种方式 所示。

基于TensorFlow中自定义梯度的2种方式

在tensorflow中有若干可以实现定制梯度的方法,这里介绍两种。

1. 重写梯度法

重写梯度法指的是通过tensorflow自带的机制,将某个节点的梯度重写(override),这种方法的适用性最广。我们这里举个例子[3].

符号函数的前向传播采用的是阶跃函数y=sign(x) y = \rm{sign}(x)y=sign(x),如下图所示,我们知道阶跃函数不是连续可导的,因此我们在反向传播时,将其替代为一个可以连续求导的函数y=Htanh(x) y = \rm{Htanh(x)}y=Htanh(x),于是梯度就是大于1和小于-1时为0,在-1和1之间时是1。

基于TensorFlow中自定义梯度的2种方式

使用重写梯度的方法如下,主要是涉及到tf.RegisterGradient()和tf.get_default_graph().gradient_override_map(),前者注册新的梯度,后者重写图中具有名字name='Sign'的操作节点的梯度,用在新注册的QuantizeGrad替代。

#使用修饰器,建立梯度反向传播函数。其中op.input包含输入值、输出值,grad包含上层传来的梯度
@tf.RegisterGradient("QuantizeGrad")
def sign_grad(op, grad):
 input = op.inputs[0] # 取出当前的输入
 cond = (input>=-1)&(input<=1) # 大于1或者小于-1的值的位置
 zeros = tf.zeros_like(grad) # 定义出0矩阵用于掩膜
 return tf.where(cond, grad, zeros) 
 # 将大于1或者小于-1的上一层的梯度置为0
 
#使用with上下文管理器覆盖原始的sign梯度函数
def binary(input):
 x = input
 with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}):
 #重写梯度
  x = tf.sign(x)
 return x
 
#使用
x = binary(x)

其中的def sign_grad(op, grad):是注册新的梯度的套路,其中的op是当前操作的输入值/张量等,而grad指的是从反向而言的上一层的梯度。

通常来说,在tensorflow中自定义梯度,函数tf.identity()是很重要的,其API手册如下:

tf.identity(
 input,
 name=None
)

其会返回一个形状和内容都和输入完全一样的输出,但是你可以自定义其反向传播时的梯度,因此在梯度反转等操作中特别有用。

这里再举个反向梯度[2]的例子,也就是梯度为 基于TensorFlow中自定义梯度的2种方式 而不是 基于TensorFlow中自定义梯度的2种方式

import tensorflow as tf
x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)
@tf.RegisterGradient('CustomGrad')
def CustomGrad(op, grad):
#  tf.Print(grad)
 return -grad
 
g = tf.get_default_graph()
oo = x1+x2
with g.gradient_override_map({"Identity": "CustomGrad"}):
 output = tf.identity(oo)
grad_1 = tf.gradients(output, oo)
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 print(sess.run(grad_1))

因为-grad,所以这里的梯度输出是[-1]而不是[1]。有一个我们需要注意的是,在自定义函数def CustomGrad()中,返回的值得是一个张量,而不能返回一个参数,比如return 0,这样会报错,如:

AttributeError: 'int' object has no attribute 'name'

显然,这是因为tensorflow的内部操作需要取返回值的名字而int类型没有名字。

PS:def CustomGrad()这个函数签名是随便你取的。

2. stop_gradient法

对于自定义梯度,还有一种比较简洁的操作,就是利用tf.stop_gradient()函数,我们看下例子[1]:

t = g(x)
y = t + tf.stop_gradient(f(x) - t)

这里,我们本来的前向传递函数是f(x),但是想要在反向时传递的函数是g(x),因为在前向过程中,tf.stop_gradient()不起作用,因此+t和-t抵消掉了,只剩下f(x)前向传递;而在反向过程中,因为tf.stop_gradient()的作用,使得f(x)-t的梯度变为了0,从而只剩下g(x)在反向传递。

我们看下完整的例子:

import tensorflow as tf

x1 = tf.Variable(1)
x2 = tf.Variable(3)
x3 = tf.Variable(6)

f = x1+x2*x3
t = -f

y1 = t + tf.stop_gradient(f-t)
y2 = f

grad_1 = tf.gradients(y1, x1)
grad_2 = tf.gradients(y2, x1)
with tf.Session(config=config) as sess:
 sess.run(tf.global_variables_initializer())

 print(sess.run(grad_1))
 print(sess.run(grad_2))

第一个输出为[-1],第二个输出为[1],显然也实现了梯度的反转。

以上这篇基于TensorFlow中自定义梯度的2种方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过ssh-powershell监控windows的方法
Jun 02 Python
通过实例浅析Python对比C语言的编程思想差异
Aug 30 Python
Python文件常见操作实例分析【读写、遍历】
Dec 10 Python
Python matplotlib绘制饼状图功能示例
Sep 10 Python
Python进程间通信 multiProcessing Queue队列实现详解
Sep 23 Python
python multiprocessing多进程变量共享与加锁的实现
Oct 02 Python
Python中__repr__和__str__区别详解
Nov 07 Python
Python实现在Windows平台修改文件属性
Mar 05 Python
简单了解python调用其他脚本方法实例
Mar 26 Python
读取nii或nii.gz文件中的信息即输出图像操作
Jul 01 Python
详解Python IO编程
Jul 24 Python
详解python中的闭包
Sep 07 Python
tensorflow 查看梯度方式
Feb 04 #Python
opencv python图像梯度实例详解
Feb 04 #Python
TensorFlow设置日志级别的几种方式小结
Feb 04 #Python
Python 实现加密过的PDF文件转WORD格式
Feb 04 #Python
解决tensorflow打印tensor有省略号的问题
Feb 04 #Python
对Tensorflow中tensorboard日志的生成与显示详解
Feb 04 #Python
在 Python 中接管键盘中断信号的实现方法
Feb 04 #Python
You might like
php,ajax实现分页
2008/03/27 PHP
Smarty变量调节器失效的解决办法
2014/08/20 PHP
jquery不支持toggle()高(新)版本的问题解决
2016/09/24 PHP
PHP内存溢出优化代码详解
2021/02/26 PHP
文本链接逐个出现的js脚本
2007/12/12 Javascript
IE6 hack for js 集锦
2014/09/23 Javascript
DOM基础教程之事件类型
2015/01/20 Javascript
Jquery Easyui日历组件Calender使用详解(23)
2016/12/18 Javascript
JS实现的自动打字效果示例
2017/03/10 Javascript
微信小程序实现带刻度尺滑块功能
2017/03/29 Javascript
详解Vue中使用v-for语句抛出错误的解决方案
2017/05/04 Javascript
基于JavaScript实现飘落星星特效
2017/08/10 Javascript
浏览器调试动态js脚本的方法(图解)
2018/01/19 Javascript
基于openlayers4实现点的扩散效果
2020/08/17 Javascript
微信小程序实现人脸识别登陆的示例代码
2019/04/02 Javascript
Vue使用自定义指令实现拖拽行为实例分析
2020/06/06 Javascript
解决vscode进行vue格式化,会自动补分号和双引号的问题
2020/10/26 Javascript
JavaScript实现鼠标经过表格某行时此行变色
2020/11/20 Javascript
Javascript生成器(Generator)的介绍与使用
2021/01/31 Javascript
python做量化投资系列之比特币初始配置
2018/01/23 Python
Python_查看sqlite3表结构,查询语句的示例代码
2019/07/17 Python
详解用python生成随机数的几种方法
2019/08/04 Python
python动态视频下载器的实现方法
2019/09/16 Python
selenium中get_cookies()和add_cookie()的用法详解
2020/01/06 Python
CSS3 box-sizing属性详解
2016/11/15 HTML / CSS
荣耀俄罗斯官网:HONOR俄罗斯
2020/10/31 全球购物
业务经理的岗位职责
2013/11/16 职场文书
写给妈妈的道歉信
2014/01/11 职场文书
中层干部竞争上岗演讲稿
2014/01/13 职场文书
学生出入校管理制度
2014/01/16 职场文书
超市国庆节促销方案
2014/02/20 职场文书
小学四年级学生评语
2014/12/26 职场文书
运动会表扬稿范文
2015/05/05 职场文书
大学生社会实践感想
2015/08/11 职场文书
小学五年级班主任工作经验交流材料
2015/11/02 职场文书
Tomcat用户管理的优化配置详解
2022/03/31 Servers