浅谈对pytroch中torch.autograd.backward的思考


Posted in Python onDecember 27, 2019

反向传递法则是深度学习中最为重要的一部分,torch中的backward可以对计算图中的梯度进行计算和累积

这里通过一段程序来演示基本的backward操作以及需要注意的地方

>>> import torch
>>> from torch.autograd import Variable

>>> x = Variable(torch.ones(2,2), requires_grad=True)
>>> y = x + 2
>>> y.grad_fn
Out[6]: <torch.autograd.function.AddConstantBackward at 0x229e7068138>
>>> y.grad

>>> z = y*y*3
>>> z.grad_fn
Out[9]: <torch.autograd.function.MulConstantBackward at 0x229e86cc5e8>
>>> z
Out[10]: 
Variable containing:
 27 27
 27 27
[torch.FloatTensor of size 2x2]
>>> out = z.mean()
>>> out.grad_fn
Out[12]: <torch.autograd.function.MeanBackward at 0x229e86cc408>
>>> out.backward()   # 这里因为out为scalar标量,所以参数不需要填写
>>> x.grad
Out[19]: 
Variable containing:
 4.5000 4.5000
 4.5000 4.5000
[torch.FloatTensor of size 2x2]
>>> out  # out为标量
Out[20]: 
Variable containing:
 27
[torch.FloatTensor of size 1]

>>> x = Variable(torch.Tensor([2,2,2]), requires_grad=True)
>>> y = x*2
>>> y
Out[52]: 
Variable containing:
 4
 4
 4
[torch.FloatTensor of size 3]
>>> y.backward() # 因为y输出为非标量,求向量间元素的梯度需要对所求的元素进行标注,用相同长度的序列进行标注
Traceback (most recent call last):
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\IPython\core\interactiveshell.py", line 2862, in run_code
  exec(code_obj, self.user_global_ns, self.user_ns)
 File "<ipython-input-53-95acac9c3254>", line 1, in <module>
  y.backward()
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\variable.py", line 156, in backward
  torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 86, in backward
  grad_variables, create_graph = _make_grads(variables, grad_variables, create_graph)
 File "C:\Users\dell\Anaconda3\envs\my-pytorch\lib\site-packages\torch\autograd\__init__.py", line 34, in _make_grads
  raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        #注意这里的0.1,1.10为梯度求值比例
Out[55]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad        # 梯度累积
Out[57]: 
Variable containing:
 0.4000
 4.0000
 40.0000
[torch.FloatTensor of size 3]

>>> x.grad.data.zero_() # 梯度累积进行清零
Out[60]: 
 0
 0
 0
[torch.FloatTensor of size 3]
>>> x.grad       # 累积为空
Out[61]: 
Variable containing:
 0
 0
 0
[torch.FloatTensor of size 3]
>>> y.backward(torch.FloatTensor([0.1, 1, 10]))
>>> x.grad
Out[63]: 
Variable containing:
 0.2000
 2.0000
 20.0000
[torch.FloatTensor of size 3]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在windows系统中实现python3安装lxml
Mar 23 Python
Python抓取框架Scrapy爬虫入门:页面提取
Dec 01 Python
Python产生Gnuplot绘图数据的方法
Nov 09 Python
Python3删除排序数组中重复项的方法分析
Jan 31 Python
Python函数中不定长参数的写法
Feb 13 Python
Python:Numpy 求平均向量的实例
Jun 29 Python
Python 实现黑客帝国中的字符雨的示例代码
Feb 20 Python
详解Python中namedtuple的使用
Apr 27 Python
Python根据指定文件生成XML的方法
Jun 29 Python
python利用proxybroker构建爬虫免费IP代理池的实现
Feb 21 Python
写一个Python脚本下载哔哩哔哩舞蹈区的所有视频
May 31 Python
Python实现byte转integer
Jun 03 Python
python实现异常信息堆栈输出到日志文件
Dec 26 #Python
Python的对象传递与Copy函数使用详解
Dec 26 #Python
Python pandas库中的isnull()详解
Dec 26 #Python
python dataframe NaN处理方式
Dec 26 #Python
python实现大战外星人小游戏实例代码
Dec 26 #Python
Python数据存储之 h5py详解
Dec 26 #Python
Python 使用 prettytable 库打印表格美化输出功能
Dec 26 #Python
You might like
短波的认识
2021/03/01 无线电
PHP中的正则表达式函数介绍
2012/02/27 PHP
Laravel (Lumen) 解决JWT-Auth刷新token的问题
2019/10/24 PHP
PHP cookie与session会话基本用法实例分析
2019/11/18 PHP
javascript之卸载鼠标事件的代码
2007/05/14 Javascript
关于Javascript作用域链的八点总结
2013/12/06 Javascript
js与jquery实时监听输入框值的oninput与onpropertychange方法
2015/02/05 Javascript
js实现发送验证码后的倒计时功能
2015/05/28 Javascript
AngularJs ng-repeat 嵌套如何获取外层$index
2016/09/21 Javascript
js中的触发事件对象event.srcElement与event.target详解
2017/03/15 Javascript
js中作用域的实例解析
2017/03/16 Javascript
JS组件系列之MVVM组件构建自己的Vue组件
2017/04/28 Javascript
Bootstrap Table 删除和批量删除
2017/09/22 Javascript
JavaScript框架Angular和React深度对比
2017/11/20 Javascript
解决Vue+Electron下Vuex的Dispatch没有效果问题
2019/05/20 Javascript
vue 使用高德地图vue-amap组件过程解析
2019/09/07 Javascript
VUE+Element实现增删改查的示例源码
2020/11/23 Vue.js
Javascript实现单选框效果
2020/12/09 Javascript
[01:38]【DOTA2亚洲邀请赛】Sumail——梦开始的地方
2017/03/03 DOTA
[03:55]DOTA2完美大师赛选手传记——LFY.MONET
2017/11/18 DOTA
Python实现拼接多张图片的方法
2014/12/01 Python
在Python的框架中为MySQL实现restful接口的教程
2015/04/08 Python
解析Python编程中的包结构
2015/10/25 Python
python3爬虫获取html内容及各属性值的方法
2018/12/17 Python
详解Django-restframework 之频率源码分析
2019/02/27 Python
详解使用PyInstaller将Pygame库编写的小游戏程序打包为exe文件
2019/08/23 Python
Python从列表推导到zip()函数的5种技巧总结
2019/10/23 Python
如何提高python 中for循环的效率
2020/04/15 Python
FOREO官方网站:LUNA露娜洁面仪
2016/11/28 全球购物
中国跨境电子商务网站:NewFrog
2018/03/10 全球购物
女装和独特珠宝:Sundance Catalog
2018/09/19 全球购物
shallow copy和deep copy的区别
2016/05/09 面试题
心理健康教育制度
2014/01/27 职场文书
学校标语口号大全
2015/12/26 职场文书
Nginx本地目录映射实现代码实例
2021/03/31 Servers
Python包管理工具pip的15 个使用小技巧
2021/05/17 Python