在pytorch中对非叶节点的变量计算梯度实例


Posted in Python onJanuary 10, 2020

在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行。

在pytorch中对非叶节点的变量计算梯度实例

注册hook函数

Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在每次计算 关于该张量 在pytorch中对非叶节点的变量计算梯度实例 的时候 被调用,经常用于调试的时候打印出非叶节点梯度。当然,通过这个手段,你也可以自定义某一层的梯度更新方法。[3] 具体到这里的打印非叶节点的梯度,代码如:

def hook_y(grad):
 print(grad)

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3

y.register_hook(hook_y) 

out = z.mean()
out.backward()

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

retain_grad()

Tensor.retain_grad()显式地保存非叶节点的梯度,当然代价就是会增加显存的消耗,而用hook函数的方法则是在反向计算时直接打印,因此不会增加显存消耗,但是使用起来retain_grad()要比hook函数方便一些。代码如:

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

以上这篇在pytorch中对非叶节点的变量计算梯度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现查询苹果手机维修进度
Mar 16 Python
几个提升Python运行效率的方法之间的对比
Apr 03 Python
Python中对元组和列表按条件进行排序的方法示例
Nov 10 Python
基于Python os模块常用命令介绍
Nov 03 Python
python按综合、销量排序抓取100页的淘宝商品列表信息
Feb 24 Python
Python+opencv 实现图片文字的分割的方法示例
Jul 04 Python
python-opencv获取二值图像轮廓及中心点坐标的代码
Aug 27 Python
python实现全排列代码(回溯、深度优先搜索)
Feb 26 Python
Python必须了解的35个关键词
Jul 16 Python
最新Python idle下载、安装与使用教程图文详解
Nov 28 Python
class类在python中获取金融数据的实例方法
Dec 10 Python
python爬取新闻门户网站的示例
Apr 25 Python
python如何获取apk的packagename和activity
Jan 10 #Python
浅谈pytorch卷积核大小的设置对全连接神经元的影响
Jan 10 #Python
python颜色随机生成器的实例代码
Jan 10 #Python
关于python pycharm中输出的内容不全的解决办法
Jan 10 #Python
Python GUI自动化实现绕过验证码登录
Jan 10 #Python
pytorch nn.Conv2d()中的padding以及输出大小方式
Jan 10 #Python
如何给Python代码进行加密
Jan 10 #Python
You might like
一台收音机,让一家人都笑逐颜开!
2020/08/21 无线电
用PHP+MySql编写聊天室
2006/10/09 PHP
PHP仿盗链代码
2012/06/03 PHP
PHP代码审核的详细介绍
2013/06/13 PHP
php中随机函数mt_rand()与rand()性能对比分析
2014/12/01 PHP
laravel5使用freetds连接sql server的方法
2018/12/07 PHP
PHP中Session ID的实现原理实例分析
2019/08/17 PHP
laravel中Redis队列监听中断的分析
2020/09/14 PHP
Javascript面向对象扩展库代码分享
2012/03/27 Javascript
JS添加删除一组文本框并对输入信息加以验证判断其正确性
2013/04/11 Javascript
纯JavaScript实现HTML5 Canvas六种特效滤镜示例
2013/06/28 Javascript
Mac/Windows下如何安装Node.js
2013/11/22 Javascript
比较不错的JS/JQuery显示或隐藏文本的方法
2014/02/13 Javascript
JavaScript中的Math 使用介绍
2014/04/21 Javascript
jQuery获取父元素节点、子元素节点及兄弟元素节点的方法
2016/04/14 Javascript
第三篇Bootstrap网格基础
2016/06/21 Javascript
Bootstrap BootstrapDialog使用详解
2017/02/17 Javascript
详解mpvue开发小程序小总结
2018/07/25 Javascript
Vue常用的几个指令附完整案例
2018/11/06 Javascript
JavaScript偏函数与柯里化实例详解
2019/03/27 Javascript
js中实现继承的五种方法
2021/01/25 Javascript
[03:01]2014DOTA2国际邀请赛 DC:我是核弹粉,为Burning和国土祝福
2014/07/13 DOTA
[49:13]DOTA2上海特级锦标赛C组资格赛#1 OG VS LGD第一局
2016/02/27 DOTA
[01:34]传奇从这开始 2016国际邀请赛中国区预选赛震撼开启
2016/06/26 DOTA
python实现简单socket程序在两台电脑之间传输消息的方法
2015/03/13 Python
Python模拟登录的多种方法(四种)
2018/06/01 Python
python计算两个地址之间的距离方法
2018/06/09 Python
Python3模拟curl发送post请求操作示例
2019/05/03 Python
Python之Numpy的超实用基础详细教程
2019/10/23 Python
详解HTML5常用的语义化标签
2019/09/27 HTML / CSS
什么是典型的软件三层结构?软件设计为什么要分层?软件分层有什么好处?
2012/03/14 面试题
如何找出EMP表里面SALARY第N高的employee
2013/12/05 面试题
介绍一下Linux中的链接
2016/06/05 面试题
电子商务专业应届生求职信
2014/05/28 职场文书
优秀班组长事迹
2014/05/31 职场文书
新郎接新娘保证书
2015/05/08 职场文书