tensorflow中的梯度求解及梯度裁剪操作


Posted in Python onMay 26, 2021

1. tensorflow中梯度求解的几种方式

1.1 tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的tensor列表list,例如

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None];参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

1.2 optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装,作用相同,但是tfgradients只返回梯度,compute_gradients返回梯度和可导的变量;tf.compute_gradients是optimizer.minimize()的第一步,optimizer.compute_gradients返回一个[(gradient, variable),…]的元组列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

optimizer = tf.train.GradientDescentOptimizer(learning_rate = 1.0)
self.train_op = optimizer.minimize(self.cost)
sess.run([train_op], feed_dict={x:data, y:labels})

在这个过程中,调用minimize方法的时候,底层进行的工作包括:

(1) 使用tf.optimizer.compute_gradients计算trainable_variables 集合中所有参数的梯度

(2) 用optimizer.apply_gradients来更新计算得到的梯度对应的变量

上面代码等价于下面代码

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients(grads_and_vars)

1.3 tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

2. 梯度裁剪

如果我们希望对梯度进行截断,那么就要自己计算出梯度,然后进行clip,最后应用到变量上,代码如下所示,接下来我们一一介绍其中的主要步骤

#return a list of trainable variable in you model
params = tf.trainable_variables()

#create an optimizer
opt = tf.train.GradientDescentOptimizer(self.learning_rate)

#compute gradients for params
gradients = tf.gradients(loss, params)

#process gradients
clipped_gradients, norm = tf.clip_by_global_norm(gradients,max_gradient_norm)

train_op = opt.apply_gradients(zip(clipped_gradients, params)))

2.1 tf.clip_by_global_norm介绍

tf.clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None)

 

t_list 表示梯度张量

clip_norm是截取的比率

在应用这个函数之后,t_list[i]的更新公示变为:

global_norm = sqrt(sum(l2norm(t)**2 for t in t_list))
t_list[i] = t_list[i] * clip_norm / max(global_norm, clip_norm)

也就是分为两步:

(1) 计算所有梯度的平方和global_norm

(2) 如果梯度平方和 global_norm 超过我们指定的clip_norm,那么就对梯度进行缩放;否则就按照原本的计算结果

梯度裁剪实例2

loss = w*x*x
optimizer = tf.train.GradientDescentOptimizer(0.1)
grads_and_vars = optimizer.compute_gradients(loss,[w,x])
grads = tf.gradients(loss,[w,x])
# 修正梯度
for i,(gradient,var) in enumerate(grads_and_vars):
    if gradient is not None:
        grads_and_vars[i] = (tf.clip_by_norm(gradient,5),var)
train_op = optimizer.apply_gradients(grads_and_vars)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads_and_vars))
     # 梯度修正前[(9.0, 2.0), (12.0, 3.0)];梯度修正后 ,[(5.0, 2.0), (5.0, 3.0)]
    print(sess.run(grads))  #[9.0, 12.0],
    print(train_op)

补充:tensorflow框架中几种计算梯度的方式

1. tf.gradients

tf.gradients(
    ys,
    xs,
    grad_ys=None,
    name='gradients',
    colocate_gradients_with_ops=False,
    gate_gradients=False,
    aggregation_method=None,
    stop_gradients=None,
    unconnected_gradients=tf.UnconnectedGradients.NONE
)

计算ys关于xs的梯度,tf.gradients返回的结果是一个长度为len(xs)的Tensor列表list,每个张量为sum(dy/dx),即ys关于xs的导数。

例子:

tf.gradients(y, [x1, x2, x3]返回[dy/dx1, dy/dx2, dy/dx3]

当y与x无关时,即graph无x到y的路径, 则求y关于x的梯度时返回[None]

参数stop_gradients指定的变量对当前梯度求解而言, 梯度求解将止于这些变量。

实例:

a = tf.constant(0.)
b = 2 * a
g = tf.gradients(a + b, [a, b], stop_gradients=[a, b]) #梯度计算不再追溯a,b之前的变量

输出:

In: sess.run(g)

out:[1.0, 1.0]

如果不设置stop_gradients参数则反向传播梯度计算将追溯到最开始的值a,输出结果为:

In : sess.run(g)

Out: [3.0, 1.0]

2. optimizer.compute_gradients

compute_gradients(
    loss,
    var_list=None,
    gate_gradients=GATE_OP,
    aggregation_method=None,
    colocate_gradients_with_ops=False,
    grad_loss=None
)

optimizer.compute_gradients是tf.gradients的封装1.

是optimizer.minimize()的第一步,返回(gradient, variable)的列表,其中gradient是tensor。

直观上,optimizer.compute_gradients只比tf.gradients多了一个variable输出。

3. tf.stop_gradient

tf.stop_gradient(
    input,
    name=None
)

tf.stop_gradient阻止input的变量参与梯度计算,即在梯度计算的过程中屏蔽input之前的graph。

返回:关于input的梯度

应用:

1、EM算法,其中M步骤不应涉及通过E步骤的输出的反向传播。

2、Boltzmann机器的对比散度训练,在区分能量函数时,训练不得反向传播通过模型生成样本的图形。

3、对抗性训练,通过对抗性示例生成过程不会发生反向训练。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详细解析Python中__init__()方法的高级应用
May 11 Python
Python新手入门最容易犯的错误总结
Apr 24 Python
利用Python实现Windows下的鼠标键盘模拟的实例代码
Jul 13 Python
python实现俄罗斯方块游戏
Mar 25 Python
Python中asyncio与aiohttp入门教程
Oct 16 Python
浅谈Python批处理文件夹中的txt文件
Mar 11 Python
深入浅析Python 函数注解与匿名函数
Feb 24 Python
Python进程Multiprocessing模块原理解析
Feb 28 Python
使用Matplotlib绘制不同颜色的带箭头的线实例
Apr 17 Python
Python操控mysql批量插入数据的实现方法
Oct 27 Python
python b站视频下载的五种版本
May 27 Python
Python+Matplotlib图像上指定坐标的位置添加文本标签与注释
Apr 11 Python
python numpy中multiply与*及matul 的区别说明
May 26 #Python
python文本处理的方案(结巴分词并去除符号)
Django操作cookie的实现
May 26 #Python
pandas中DataFrame检测重复值的实现
python 中的@运算符使用
May 26 #Python
Python 实现定积分与二重定积分的操作
May 26 #Python
python 解决微分方程的操作(数值解法)
You might like
php 文件上传代码(限制jpg文件)
2010/01/05 PHP
PHP源码之explode使用说明
2011/08/05 PHP
PHP判断用户是否已经登录(跳转到不同页面或者执行不同动作)
2016/09/22 PHP
PHP魔术方法之__call与__callStatic使用方法
2017/07/23 PHP
PHP使用openssl扩展实现加解密方法示例
2020/02/20 PHP
5款Javascript颜色选择器
2009/10/25 Javascript
js修改原型的属性使用介绍
2014/01/26 Javascript
当前流行的JavaScript代码风格指南
2014/09/10 Javascript
javascript制作坦克大战全纪录(1)
2014/11/27 Javascript
jQuery实现仿QQ空间装扮预览图片的鼠标提示效果代码
2015/10/30 Javascript
AngularJS数据源的多种获取方式汇总
2016/02/02 Javascript
基于javascript实现九宫格大转盘效果
2020/05/28 Javascript
对象转换为原始值的实现方法
2016/06/06 Javascript
浅谈jQuery添加的HTML,JS失效的问题
2016/10/05 Javascript
Node.JS利用PhantomJs抓取网页入门教程
2017/05/19 Javascript
利用js将ajax获取到的后台数据动态加载至网页中的方法
2018/08/08 Javascript
Angular项目如何升级至Angular6步骤全纪录
2018/09/03 Javascript
vuex直接赋值的三种方法总结
2018/09/16 Javascript
Bootstrap 实现表格样式、表单布局的实例代码
2018/12/09 Javascript
Python中内置的日志模块logging用法详解
2016/07/12 Python
python中异常捕获方法详解
2017/03/03 Python
Python实现的微信公众号群发图片与文本消息功能实例详解
2017/06/30 Python
浅析python的Lambda表达式
2019/02/27 Python
python实现图片中文字分割效果
2019/07/22 Python
Android面试题附答案
2014/12/08 面试题
介绍一下linux的文件权限
2014/07/20 面试题
商场消防演习方案
2014/02/12 职场文书
一位农村小子的自荐信
2014/04/07 职场文书
社会实践活动总结报告
2014/04/29 职场文书
学校就业推荐信范文
2014/05/19 职场文书
好的促销活动方案
2014/08/21 职场文书
2015年幼儿园保育员工作总结
2015/04/23 职场文书
单位证明范文
2015/06/18 职场文书
学校财务管理制度
2015/08/04 职场文书
创业计划书之农家乐
2019/10/09 职场文书
JavaScript中的宏任务和微任务详情
2021/11/27 Javascript