浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python MySQLdb模块连接操作mysql数据库实例
Apr 08 Python
Django中模版的子目录与include标签的使用方法
Jul 16 Python
Linux中Python 环境软件包安装步骤
Mar 31 Python
Python探索之修改Python搜索路径
Oct 25 Python
[原创]教女朋友学Python(一)运行环境搭建
Nov 29 Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 Python
详解python使用turtle库来画一朵花
Mar 21 Python
Python数据类型之Set集合实例详解
May 07 Python
Python数据结构dict常用操作代码实例
Mar 12 Python
python使用Thread的setDaemon启动后台线程教程
Apr 25 Python
简单介绍一下pyinstaller打包以及安全性的实现
Jun 02 Python
详解Python如何批量采集京东商品数据流程
Jan 22 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
Sample script that deletes a SQL Server database
2007/06/16 Javascript
用js实现计算代码行数的简单方法附代码
2007/08/13 Javascript
javascript中的变量作用域以及变量提升详细介绍
2013/10/24 Javascript
JS常用函数使用指南
2014/11/23 Javascript
jQuery 中DOM 操作详解
2015/01/13 Javascript
jQuery中的pushStack实现原理和应用实例
2015/02/03 Javascript
九种原生js动画效果
2015/11/11 Javascript
jQuery ajax 当async为false时解决同步操作失败的问题
2016/11/18 Javascript
jquery validation验证表单插件
2017/01/07 Javascript
angularjs实现多张图片上传并预览功能
2017/02/24 Javascript
使用Browserify来实现CommonJS的浏览器加载方法
2017/05/14 Javascript
js处理包含中文的字符串实例
2017/10/11 Javascript
vue.js打包之后可能会遇到的坑!
2018/06/03 Javascript
Vue内部渲染视图的方法
2019/09/02 Javascript
JavaScript获取当前url路径过程解析
2019/12/27 Javascript
js实现列表向上无限滚动
2020/01/13 Javascript
[02:27]2018DOTA2亚洲邀请赛赛前采访-OpTic
2018/04/03 DOTA
Python中的yield浅析
2014/06/16 Python
解决windows下Sublime Text 2 运行 PyQt 不显示的方法分享
2014/06/18 Python
python中List的sort方法指南
2014/09/01 Python
python实现简易数码时钟
2021/02/19 Python
Python对接支付宝支付自实现功能
2019/10/10 Python
文件上传服务器-jupyter 中python解压及压缩方式
2020/04/22 Python
CSS3实现时间轴效果
2016/07/11 HTML / CSS
联想墨西哥官方网站:Lenovo墨西哥
2016/08/17 全球购物
英国领先的新鲜松露和最好的松露产品供应商:TruffleHunter
2019/08/26 全球购物
澳大利亚著名的纺织品品牌:Canningvale
2020/05/05 全球购物
输入N,打印N*N矩阵
2012/02/20 面试题
几个Linux面试题笔试题
2012/12/01 面试题
移动通信专业自荐信范文
2013/11/12 职场文书
酒店执行总经理岗位职责
2013/12/15 职场文书
大学班级干部的自我评价分享
2014/02/10 职场文书
招商银行工作证明
2015/06/17 职场文书
2016年教师政治思想表现评语
2015/12/02 职场文书
数据库连接池
2021/04/06 MySQL
Python 如何将integer转化为罗马数(3999以内)
2021/06/05 Python