解决torch.autograd.backward中的参数问题


Posted in Python onJanuary 07, 2020

torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=False)

给定图的叶子节点variables, 计算图中变量的梯度和。 计算图可以通过链式法则求导。如果variables中的任何一个variable是 非标量(non-scalar)的,且requires_grad=True。那么此函数需要指定grad_variables,它的长度应该和variables的长度匹配,里面保存了相关variable的梯度(对于不需要gradient tensor的variable,None是可取的)。

此函数累积leaf variables计算的梯度。你可能需要在调用此函数之前将leaf variable的梯度置零。

参数:

variables(变量的序列) - 被求微分的叶子节点,即 ys 。

grad_variables((张量,变量)的序列或无) - 对应variable的梯度。仅当variable不是标量且需要求梯度的时候使用。

retain_graph(bool,可选) - 如果为False,则用于释放计算grad的图。请注意,在几乎所有情况下,没有必要将此选项设置为True,通常可以以更有效的方式解决。默认值为create_graph的值。

create_graph(bool,可选) - 如果为True,则将构造派生图,允许计算更高阶的派生产品。默认为False。

我这里举一个官方的例子

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()
out.backward()#这里是默认情况,相当于out.backward(torch.Tensor([1.0]))
print(x.grad)

输出结果是

Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]

解决torch.autograd.backward中的参数问题

接着我们继续

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
  y = y * 2

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)

输出结果是

Variable containing:
 204.8000
 2048.0000
  0.2048
[torch.FloatTensor of size 3]

这里这个gradients为什么要是[0.1, 1.0, 0.0001]?

如果输出的多个loss权重不同的话,例如有三个loss,一个是x loss,一个是y loss,一个是class loss。那么很明显的不可能所有loss对结果影响程度都一样,他们之间应该有一个比例。那么比例这里指的就是[0.1, 1.0, 0.0001],这个问题中的loss对应的就是上面说的y,那么这里的输出就很好理解了dy/dx=0.1*dy1/dx+1.0*dy2/dx+0.0001*dy3/dx。

如有问题,希望大家指正,谢谢_!

以上这篇解决torch.autograd.backward中的参数问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python性能优化的20条建议
Oct 25 Python
几个提升Python运行效率的方法之间的对比
Apr 03 Python
python并发编程之多进程、多线程、异步和协程详解
Oct 28 Python
python GUI实例学习
Nov 21 Python
Python实现求解括号匹配问题的方法
Apr 17 Python
python用pandas数据加载、存储与文件格式的实例
Dec 07 Python
python 通过视频url获取视频的宽高方式
Dec 10 Python
python默认参数调用方法解析
Feb 09 Python
pytorch 模型的train模式与eval模式实例
Feb 20 Python
python读取excel数据并且画图的实现示例
Feb 08 Python
python使用tkinter实现透明窗体上绘制随机出现的小球(实例代码)
May 17 Python
Python自动操作神器PyAutoGUI的使用教程
Jun 16 Python
Pytorch 中retain_graph的用法详解
Jan 07 #Python
PyTorch中的Variable变量详解
Jan 07 #Python
python enumerate内置函数用法总结
Jan 07 #Python
pytorch加载自定义网络权重的实现
Jan 07 #Python
Matplotlib绘制雷达图和三维图的示例代码
Jan 07 #Python
Pytorch 神经网络—自定义数据集上实现教程
Jan 07 #Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
You might like
php实现四舍五入的方法小结
2015/03/03 PHP
PHP超全局数组(Superglobals)介绍
2015/07/01 PHP
详解WordPress开发中过滤属性以及Sql语句的函数使用
2015/12/25 PHP
PHP中TP5 上传文件的实例详解
2017/07/31 PHP
CentOS7系统搭建LAMP及更新PHP版本操作详解
2020/03/26 PHP
一个js的tab切换效果代码[代码分离]
2010/04/11 Javascript
13个绚丽的Jquery 界面设计网站推荐
2010/09/28 Javascript
两个数组去重的JS代码
2013/12/04 Javascript
JavaScript实现简单图片翻转的方法
2015/04/17 Javascript
Node.js实现JS文件合并小工具
2016/02/02 Javascript
JavaScript从0开始构思表情插件
2016/07/26 Javascript
js实现随机抽选效果、随机抽选红色球效果
2017/01/13 Javascript
JS仿QQ好友列表展开、收缩功能(第一篇)
2017/07/07 Javascript
关于react-router的几种配置方式详解
2017/07/24 Javascript
关于TypeScript中import JSON的正确姿势详解
2017/07/25 Javascript
js中call()和apply()改变指针问题的讲解
2019/01/17 Javascript
配置eslint规范项目代码风格
2019/03/11 Javascript
Vue插件之滑动验证码用法详解
2020/04/05 Javascript
[01:18:36]LGD vs VP Supermajor 败者组决赛 BO3 第一场 6.10
2018/07/04 DOTA
win7 下搭建sublime的python开发环境的配置方法
2014/06/18 Python
Python序列之list和tuple常用方法以及注意事项
2015/01/09 Python
win与linux系统中python requests 安装
2016/12/04 Python
Python使用PIL模块生成随机验证码
2017/11/21 Python
通过Py2exe将自己的python程序打包成.exe/.app的方法
2018/05/26 Python
对python中的iter()函数与next()函数详解
2018/10/18 Python
Python3.5集合及其常见运算实例详解
2019/05/01 Python
python与C、C++混编的四种方式(小结)
2019/07/15 Python
Python实现上下文管理器的方法
2020/08/07 Python
携程英文网站:Trip.com
2017/02/07 全球购物
Skyscanner波兰:廉价航班
2017/11/07 全球购物
营销主管自我评价怎么写
2013/09/19 职场文书
50岁生日感言
2014/01/23 职场文书
党委书记个人对照检查材料
2014/09/15 职场文书
2016庆祝教师节新闻稿
2015/11/25 职场文书
2016年基层党组织创先争优承诺书
2016/03/25 职场文书
关于MySQL临时表为什么可以重名的问题
2022/03/22 MySQL