Tensorflow之梯度裁剪的实现示例


Posted in Python onMarch 08, 2020

tensorflow中的梯度计算和更新

为了解决深度学习中常见的梯度消失(gradient explosion)和梯度爆炸(gradients vanishing)问题,tensorflow中所有的优化器tf.train.xxxOptimizer都有两个方法:

  1. compute_gradients
  2. apply_gradients

compute_gradients

对于compute_gradients方法,计算var_list中参数的梯度,使得loss变小。默认情况下,var_list为GraphKeys.TRAINABLE_VARIABLES中的所有参数。

compute_gradients方法返回由多个(gradients, variable)二元组组成的列表。

compute_gradients(
  loss,
  var_list=None,
  gate_gradients=GATE_OP,
  aggregation_method=None,
  colocate_gradients_with_ops=False,
  grad_loss=None
)

apply_gradients

对于apply_gradients方法,根据compute_gradients的返回结果对参数进行更新

apply_gradients(
  grads_and_vars,
  global_step=None,
  name=None
)

梯度裁剪(Gradient Clipping)

tensorflow中裁剪梯度的几种方式

方法一tf.clip_by_value

def clip_by_value(t, clip_value_min, clip_value_max,
         name=None):

其中,t为一个张量,clip_by_value返回一个与t的type相同、shape相同的张量,但是新tensor中的值被裁剪到了clip_value_min和clip_value_max之间。

方法二:tf.clip_by_global_norm

def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):

其中,t_list为A tuple or list of mixed Tensors, IndexedSlices, or None。clip_norm为clipping ratio,use_norm指定global_norm,如果use_norm为None,则按global_norm = sqrt(sum([l2norm(t)**2 for t in t_list]))计算global_norm。

最终,梯度的裁剪方式为

Tensorflow之梯度裁剪的实现示例

可知,如果clip_norm > global_norm, 则不对梯度进行裁剪,否则对梯度进行缩放。

scale = clip_norm * math_ops.minimum(
    1.0 / use_norm,
    constant_op.constant(1.0, dtype=use_norm.dtype) / clip_norm)

方法的返回值为裁剪后的梯度列表list_clipped和global_norm

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate)
gradients, v = zip(*optimizer.compute_gradients(loss))
gradients, _ = tf.clip_by_global_norm(gradients, grad_clip)
updates = optimizer.apply_gradients(zip(gradients, v),global_step=global_step)

方法三tf.clip_by_average_norm

def clip_by_average_norm(t, clip_norm, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式如下,

Tensorflow之梯度裁剪的实现示例

其中,avg_norm=l2norm_avg(t)

方法四:tf.clip_by_norm

def clip_by_norm(t, clip_norm, axes=None, name=None):

t为张量,clip_norm为maximum clipping value

裁剪方式为

Tensorflow之梯度裁剪的实现示例

示例代码

optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
for i, (g, v) in enumerate(grads):
  if g is not None:
    grads[i] = (tf.clip_by_norm(g, 5), v) # clip gradients
train_op = optimizer.apply_gradients(grads)

注意到,clip_by_value、clib_by-avg_norm和clip_by_norm都是针对于单个张量的,而clip_by_global_norm可用于多个张量组成的列表。

到此这篇关于Tensorflow之梯度裁剪的实现示例的文章就介绍到这了,更多相关Tensorflow 梯度裁剪内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
使用Python中PDB模块中的命令来调试Python代码的教程
Mar 30 Python
Python和Perl绘制中国北京跑步地图的方法
Mar 03 Python
利用Python自带PIL库扩展图片大小给图片加文字描述的方法示例
Aug 08 Python
轻量级的Web框架Flask 中模块化应用的实现
Sep 11 Python
利用python将图片转换成excel文档格式
Dec 30 Python
python 循环读取txt文档 并转换成csv的方法
Oct 26 Python
Python3中的bytes和str类型详解
May 02 Python
基于django ManyToMany 使用的注意事项详解
Aug 09 Python
手把手教你pycharm专业版安装破解教程(linux版)
Sep 26 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 Python
Django使用Celery加redis执行异步任务的实例内容
Feb 20 Python
python爬虫scrapy图书分类实例讲解
Nov 23 Python
Django自定义全局403、404、500错误页面的示例代码
Mar 08 #Python
Django 自定义404 500等错误页面的实现
Mar 08 #Python
Python loguru日志库之高效输出控制台日志和日志记录
Mar 07 #Python
Centos7下源码安装Python3 及shell 脚本自动安装Python3的教程
Mar 07 #Python
Django接收照片储存文件的实例代码
Mar 07 #Python
Python实现对adb命令封装
Mar 06 #Python
对Python中 \r, \n, \r\n的彻底理解
Mar 06 #Python
You might like
东芝TOSHIBA RP-F11电路分析
2021/03/02 无线电
用PHP生成静态HTML速度快类库
2007/03/18 PHP
让php处理图片变得简单 基于gb库的图片处理类附实例代码下载
2011/05/17 PHP
关于PHP递归算法和应用方法介绍
2013/04/15 PHP
php延迟静态绑定实例分析
2015/02/08 PHP
javascript中CheckBox全选终极方案
2015/05/20 Javascript
JavaScript如何禁止Backspace键
2015/12/02 Javascript
理解javascript定时器中的setTimeout与setInterval
2016/02/23 Javascript
基于JQuery实现的跑马灯效果(文字无缝向上翻动)
2016/12/02 Javascript
原生js实现图片放大缩小计时器效果
2017/01/20 Javascript
vue.js学习之UI组件开发教程
2017/07/03 Javascript
Vue中使用Sortable的示例代码
2018/04/07 Javascript
使用layui的layer组件做弹出层的例子
2019/09/27 Javascript
jQuery实现全选、反选和不选功能的方法详解
2019/12/04 jQuery
微信小程序实现吸顶特效
2020/01/08 Javascript
JavaScript遍历数组的方法代码实例
2020/01/14 Javascript
vue修改Element的el-table样式的4种方法
2020/09/17 Javascript
Vant 中的Toast设置全局的延迟时间操作
2020/11/04 Javascript
python中wx将图标显示在右下角的脚本代码
2013/03/08 Python
python发送伪造的arp请求
2014/01/09 Python
介绍Python的Urllib库的一些高级用法
2015/04/30 Python
python的变量与赋值详细分析
2017/11/08 Python
Python将多个list合并为1个list的方法
2018/06/27 Python
PyTorch加载自己的数据集实例详解
2020/03/18 Python
python如何实现读取并显示图片(不需要图形界面)
2020/07/08 Python
python工具快速为音视频自动生成字幕(使用说明)
2021/01/27 Python
使用css实现android系统的loading加载动画
2019/07/25 HTML / CSS
CSS3实现歌词进度文字颜色填充变化动态效果的思路详解
2020/06/02 HTML / CSS
西班牙英格列斯百货官网:El Corte Inglés
2016/09/25 全球购物
工伤事故赔偿协议书
2014/10/27 职场文书
保卫工作个人总结
2015/03/03 职场文书
幼儿园班级工作总结2015
2015/05/25 职场文书
古诗之爱国古诗5首
2019/09/20 职场文书
原生Javascript+HTML5一步步实现拖拽排序
2021/06/12 Javascript
Java中CyclicBarrier和CountDownLatch的用法与区别
2021/08/23 Java/Android
使用CSS设置滚动条样式
2022/01/18 HTML / CSS