Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
编写Python脚本来实现最简单的FTP下载的教程
May 04 Python
用python写的一个wordpress的采集程序
Feb 27 Python
python实现定时自动备份文件到其他主机的实例代码
Feb 23 Python
详解python3中的真值测试
Aug 13 Python
python scp 批量同步文件的实现方法
Jan 03 Python
python自动化UI工具发送QQ消息的实例
Aug 27 Python
python 微信好友特征数据分析及可视化
Jan 07 Python
Python&&GDAL实现NDVI的计算方式
Jan 09 Python
Keras实现DenseNet结构操作
Jul 06 Python
pycharm中使用request和Pytest进行接口测试的方法
Jul 31 Python
python lambda的使用详解
Feb 26 Python
Python基础之条件语句详解
Jun 16 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
Eclipse的PHP插件PHPEclipse安装和使用
2014/07/20 PHP
基于CI框架的微信网页授权库示例
2016/11/25 PHP
Javascript技术技巧大全(五)
2007/01/22 Javascript
IE innerHTML,outerHTML所引起的问题
2009/06/04 Javascript
兼容IE与firefox火狐的回车事件(js与jquery)
2010/10/20 Javascript
Javascript创建自定义对象 创建Object实例添加属性和方法
2012/06/04 Javascript
禁止页面刷新让F5快捷键及右键都无效
2014/01/22 Javascript
javascript+ajax实现产品页面加载信息
2015/07/09 Javascript
JavaScript判断数字是否为质数的方法汇总
2016/06/02 Javascript
jquery实现图片列表鼠标移入微动
2016/12/01 Javascript
利用Vue v-model实现一个自定义的表单组件
2017/04/27 Javascript
利用node.js制作命令行工具方法教程(一)
2017/06/22 Javascript
Angular排序实例详解
2017/06/28 Javascript
vue-cli3.0 特性解读
2018/04/22 Javascript
JavaScript中toLocaleString()和toString()的区别实例分析
2018/08/14 Javascript
Vue.js获取被选择的option的value和text值方法
2018/08/24 Javascript
用node撸一个监测复联4开售短信提醒的实现代码
2019/04/10 Javascript
在Vue中使用this.$store或者是$route一直报错的解决
2019/11/08 Javascript
JS+HTML实现自定义上传图片按钮并显示图片功能的方法分析
2020/02/12 Javascript
js实现随机圆与矩形功能
2020/10/29 Javascript
python创建和使用字典实例详解
2013/11/01 Python
Python 比较文本相似性的方法(difflib,Levenshtein)
2018/10/15 Python
Python爬虫设置代理IP(图文)
2018/12/23 Python
Python shelve模块实现解析
2019/08/28 Python
北美三大旅游网站之一:Travelocity加拿大
2016/08/20 全球购物
国际领先的学术出版商:Springer
2017/01/11 全球购物
Sofft鞋官网:世界知名鞋类品牌
2017/03/28 全球购物
"火柴棍式"程序员面试题
2014/03/16 面试题
Shell脚本如何向终端输出信息
2014/04/25 面试题
中医临床专业自我鉴定范文
2014/01/15 职场文书
商场消防演习方案
2014/02/12 职场文书
《圆明园的毁灭》教学反思
2014/02/28 职场文书
维稳工作承诺书
2015/01/20 职场文书
父亲节活动总结
2015/02/12 职场文书
高中生综合素质评价范文
2015/08/18 职场文书
Python实现byte转integer
2021/06/03 Python