浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3+dlib实现人脸识别和情绪分析
Apr 21 Python
Python流行ORM框架sqlalchemy安装与使用教程
Jun 04 Python
python f-string式格式化听语音流程讲解
Jun 18 Python
简单瞅瞅Python vars()内置函数的实现
Sep 27 Python
解决django后台管理界面添加中文内容乱码问题
Nov 15 Python
Python+OpenCV 实现图片无损旋转90°且无黑边
Dec 12 Python
JetBrains PyCharm(Community版本)的下载、安装和初步使用图文教程详解
Mar 19 Python
Python SQLAlchemy库的使用方法
Oct 13 Python
python生成随机数、随机字符、随机字符串
Apr 06 Python
Python机器学习之PCA降维算法详解
May 19 Python
pytorch训练神经网络爆内存的解决方案
May 22 Python
Python  lambda匿名函数和三元运算符
Apr 19 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
一个SQL管理员的web接口
2006/10/09 PHP
PHP4和PHP5共存于一系统
2006/11/17 PHP
PHP中利用substr_replace将指定两位置之间的字符替换为*号
2011/01/27 PHP
php中switch与ifelse的效率区别及适用情况分析
2015/02/12 PHP
Thinkphp通过一个入口文件如何区分移动端和PC端
2017/04/18 PHP
PHP实现表单提交数据的验证处理功能【防SQL注入和XSS攻击等】
2017/07/21 PHP
PHP FileSystem 文件系统常用api整理总结
2019/07/12 PHP
javascript getElementsByClassName 和js取地址栏参数
2010/01/02 Javascript
javascript将数字转换整数金额大写的方法
2015/01/27 Javascript
jQuery监控文本框事件并作相应处理的方法
2015/04/16 Javascript
jQuery Ajax 上传文件处理方式介绍(推荐)
2016/06/30 Javascript
自己封装的一个原生JS拖动方法(推荐)
2016/11/22 Javascript
鼠标经过出现气泡框的简单实例
2017/03/17 Javascript
ES6(ECMAScript 6)新特性之模板字符串用法分析
2017/04/01 Javascript
layerUI下的绑定事件实例代码
2018/08/17 Javascript
javascript实现切割轮播效果
2019/11/28 Javascript
vue 添加和编辑用同一个表单,el-form表单提交后清空表单数据操作
2020/08/03 Javascript
JavaScript实现网页下拉菜单效果
2020/11/20 Javascript
pandas groupby 分组取每组的前几行记录方法
2018/04/20 Python
在dataframe两列日期相减并且得到具体的月数实例
2018/07/03 Python
Python创建字典的八种方式
2019/02/27 Python
Python爬虫:将headers请求头字符串转为字典的方法
2019/08/21 Python
python实现计算器功能
2019/10/31 Python
pytorch标签转onehot形式实例
2020/01/02 Python
Python基础之字符串常见操作经典实例详解
2020/02/26 Python
台湾最大银发乐活百货:乐龄网
2018/05/21 全球购物
会计专业自荐信范文
2013/12/02 职场文书
高二地理教学反思
2014/01/24 职场文书
孩子教育的心得体会
2014/09/01 职场文书
中学生打架检讨书
2014/10/13 职场文书
2014五年级班主任工作总结
2014/12/05 职场文书
行政处罚听证告知书
2015/07/01 职场文书
《穷人》教学反思
2016/02/19 职场文书
某学校的2019年度工作报告范本
2019/10/11 职场文书
如何使用flask将模型部署为服务
2021/05/13 Python
js作用域及作用域链工作引擎
2022/07/07 Javascript