在pytorch中对非叶节点的变量计算梯度实例


Posted in Python onJanuary 10, 2020

在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行。

在pytorch中对非叶节点的变量计算梯度实例

注册hook函数

Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在每次计算 关于该张量 在pytorch中对非叶节点的变量计算梯度实例 的时候 被调用,经常用于调试的时候打印出非叶节点梯度。当然,通过这个手段,你也可以自定义某一层的梯度更新方法。[3] 具体到这里的打印非叶节点的梯度,代码如:

def hook_y(grad):
 print(grad)

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3

y.register_hook(hook_y) 

out = z.mean()
out.backward()

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

retain_grad()

Tensor.retain_grad()显式地保存非叶节点的梯度,当然代价就是会增加显存的消耗,而用hook函数的方法则是在反向计算时直接打印,因此不会增加显存消耗,但是使用起来retain_grad()要比hook函数方便一些。代码如:

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)

输出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

以上这篇在pytorch中对非叶节点的变量计算梯度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python模块restful使用方法实例
Dec 10 Python
centos 下面安装python2.7 +pip +mysqld
Nov 18 Python
python机器学习案例教程——K最近邻算法的实现
Dec 28 Python
Python操作Sql Server 2008数据库的方法详解
May 17 Python
Python函数any()和all()的用法及区别介绍
Sep 14 Python
Python Django 添加首页尾页上一页下一页代码实例
Aug 21 Python
python获取array中指定元素的示例
Nov 26 Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 Python
浅析Python requests 模块
Oct 09 Python
Python趣味爬虫之用Python实现智慧校园一键评教
May 28 Python
python套接字socket通信
Apr 01 Python
python如何获取apk的packagename和activity
Jan 10 #Python
浅谈pytorch卷积核大小的设置对全连接神经元的影响
Jan 10 #Python
python颜色随机生成器的实例代码
Jan 10 #Python
关于python pycharm中输出的内容不全的解决办法
Jan 10 #Python
Python GUI自动化实现绕过验证码登录
Jan 10 #Python
pytorch nn.Conv2d()中的padding以及输出大小方式
Jan 10 #Python
如何给Python代码进行加密
Jan 10 #Python
You might like
PHP 读取和修改大文件的某行内容的代码
2009/10/30 PHP
提高define性能的php扩展hidef的安装和使用
2011/06/14 PHP
PHP中的use关键字概述
2014/07/23 PHP
PHP实现的比较完善的购物车类
2014/12/02 PHP
PHP 微信扫码支付源代码(推荐)
2016/11/03 PHP
php 读写json文件及修改json的方法
2018/03/07 PHP
laravel5使用freetds连接sql server的方法
2018/12/07 PHP
js类中的公有变量和私有变量
2008/07/24 Javascript
jquery 实现checkbox全选,反选,全不选等功能代码(奇数)
2012/10/24 Javascript
js去除输入框中所有的空格和禁止输入空格的方法
2014/06/09 Javascript
JS输入用户名自动显示邮箱后缀列表的方法
2015/01/27 Javascript
JavaScript获取伪元素(Pseudo-Element)属性的方法技巧
2015/03/13 Javascript
基于JavaScript实现通用tab选项卡(通用性强)
2016/01/07 Javascript
简述Matlab中size()函数的用法
2016/03/20 Javascript
浅谈JavaScript事件绑定的常用方法及其优缺点分析
2016/11/01 Javascript
微信小程序中用WebStorm使用LESS
2017/03/08 Javascript
ES6字符串模板,剩余参数,默认参数功能与用法示例
2017/04/06 Javascript
获取url中用&隔开的参数实例(分享)
2017/05/28 Javascript
JS对象序列化成json数据和json数据转化为JS对象的代码
2017/08/23 Javascript
微信小程序如何自定义table组件
2019/06/29 Javascript
vue项目中极验验证的使用代码示例
2019/12/03 Javascript
vue解决跨域问题(推荐)
2020/11/10 Javascript
[48:31]DOTA2-DPC中国联赛 正赛 Dynasty vs XG BO3 第一场 2月2日
2021/03/11 DOTA
介绍Python的@property装饰器的用法
2015/04/28 Python
利用Pandas读取文件路径或文件名称包含中文的csv文件方法
2018/07/04 Python
Python读写文件基础知识点
2019/06/10 Python
django框架实现一次性上传多个文件功能示例【批量上传】
2019/06/19 Python
pytorch实现线性拟合方式
2020/01/15 Python
详解PyQt5中textBrowser显示print语句输出的简单方法
2020/08/07 Python
超市创业计划书
2014/09/15 职场文书
中职毕业生自我鉴定范文(3篇)
2014/09/28 职场文书
2014年教学管理工作总结
2014/12/02 职场文书
银行安全保卫工作总结
2015/08/10 职场文书
2016年社区国庆节活动总结
2016/04/01 职场文书
中国古代史学名著《战国策》概述
2019/08/09 职场文书
springboot中rabbitmq实现消息可靠性机制详解
2021/09/25 Java/Android