Pytorch中的自动求梯度机制和Variable类实例


Posted in Python onFebruary 29, 2020

自动求导机制是每一个深度学习框架中重要的性质,免去了手动计算导数,下面用代码介绍并举例说明Pytorch的自动求导机制。

首先介绍Variable,Variable是对Tensor的一个封装,操作和Tensor是一样的,但是每个Variable都有三个属性:Varibale的Tensor本身的.data,对应Tensor的梯度.grad,以及这个Variable是通过什么方式得到的.grad_fn,根据最新消息,在pytorch0.4更新后,torch和torch.autograd.Variable现在是同一类。torch.Tensor能像Variable那样追踪历史和反向传播。Variable仍能正确工作,但是返回的是Tensor。

我们拥抱这些新特性,看看Pytorch怎么进行自动求梯度。

#encoding:utf-8
import torch

x = torch.tensor([2.],requires_grad=True) #新建一个tensor,允许自动求梯度,这一项默认是false.
y = (x+2)**2 + 3 #y的表达式中包含x,因此y能进行自动求梯度
y.backward()
print(x.grad)

输出结果是:

tensor([8.])

这里添加一个小知识点,即torch.Tensor和torch.tensor的不同。二者均可以生成新的张量,但torch.Tensor()是python类,是默认张量类型torch.FloatTensor()的别名,使用torch.Tensor()会调用构造函数,生成单精度浮点类型的张量。

而torch.tensor()是函数,其中data可以是list,tuple,numpy,ndarray,scalar和其他类型,但只有浮点类型的张量能够自动求梯度。

torch.tensor(data, dtype=None, device=None, requires_grad=False)

言归正传,上一个例子的变量本质上是标量。下面一个例子对矩阵求导。

#encoding:utf-8
import torch

x = torch.ones((2,4),requires_grad=True)
y = torch.ones((2,1),requires_grad=True)
W = torch.ones((4,1),requires_grad=True)

J = torch.sum(y - torch.matmul(x,W)) #torch.matmul()表示对矩阵作乘法
J.backward()
print(x.grad)
print(y.grad)
print(W.grad)

输出结果是:

tensor([[-1., -1., -1., -1.],
   [-1., -1., -1., -1.]])
tensor([[1.],
   [1.]])
tensor([[-2.],
   [-2.],
   [-2.],
   [-2.]])

以上这篇Pytorch中的自动求梯度机制和Variable类实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
django 解决manage.py migrate无效的问题
May 27 Python
Python设计模式之外观模式实例详解
Jan 17 Python
python实现定时发送qq消息
Jan 18 Python
Ubuntu下Python+Flask分分钟搭建自己的服务器教程
Nov 19 Python
Python实现结构体代码实例
Feb 10 Python
Python request使用方法及问题总结
Apr 26 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 Python
使用Keras构造简单的CNN网络实例
Jun 29 Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 Python
浅谈Django前端后端值传递问题
Jul 15 Python
python中字符串的编码与解码详析
Dec 03 Python
Python爬虫之爬取某文库文档数据
Apr 21 Python
在pytorch中实现只让指定变量向后传播梯度
Feb 29 #Python
浅谈Pytorch中的自动求导函数backward()所需参数的含义
Feb 29 #Python
python数据预处理 :样本分布不均的解决(过采样和欠采样)
Feb 29 #Python
python实现门限回归方式
Feb 29 #Python
Python3.9又更新了:dict内置新功能
Feb 28 #Python
python实现logistic分类算法代码
Feb 28 #Python
python GUI库图形界面开发之PyQt5打印控件QPrinter详细使用方法与实例
Feb 28 #Python
You might like
什么是短波收听SWL
2021/03/01 无线电
MySql中正则表达式的使用方法描述
2008/07/30 PHP
rephactor 优秀的PHP的重构工具
2011/06/09 PHP
PHP中strtotime函数使用方法分享
2012/01/10 PHP
php实现excel中rank函数功能的方法
2015/01/20 PHP
PHP常见数组函数用法小结
2016/03/21 PHP
php实现JWT(json web token)鉴权实例详解
2019/11/05 PHP
用javascript获取地址栏参数
2006/12/22 Javascript
JavaScript的9个陷阱及评点分析
2008/05/16 Javascript
jquery实现类似淘宝星星评分功能实例
2014/09/12 Javascript
JavaScript中split() 使用方法汇总
2015/04/17 Javascript
原生js代码实现图片放大境效果
2016/10/30 Javascript
使用Vue.js创建一个时间跟踪的单页应用
2016/11/28 Javascript
thinkphp标签实现bootsrtap轮播carousel实例代码
2017/02/19 Javascript
vue动态生成dom并且自动绑定事件
2017/04/19 Javascript
vue.js开发环境搭建教程
2017/05/04 Javascript
jQuery Validate格式验证功能实例代码(包括重名验证)
2017/07/18 jQuery
jQuery zTree 异步加载添加子节点重复问题
2017/11/29 jQuery
AngularJS集合数据遍历显示的实例
2017/12/27 Javascript
JS实现判断移动端PC端功能
2020/02/21 Javascript
原生js实现简单轮播图
2020/10/26 Javascript
[40:27]完美世界DOTA2联赛PWL S3 PXG vs GXR 第一场 12.19
2020/12/24 DOTA
实例讲解Python编程中@property装饰器的用法
2016/06/20 Python
python中urlparse模块介绍与使用示例
2017/11/19 Python
pycharm 主题theme设置调整仿sublime的方法
2018/05/23 Python
python PrettyTable模块的安装与简单应用
2019/01/11 Python
Django组件cookie与session的具体使用
2019/06/05 Python
用Pelican搭建一个极简静态博客系统过程解析
2019/08/22 Python
python 将dicom图片转换成jpg图片的实例
2020/01/13 Python
Python创建空列表的字典2种方法详解
2020/02/13 Python
Win10下安装并使用tensorflow-gpu1.8.0+python3.6全过程分析(显卡MX250+CUDA9.0+cudnn)
2020/02/17 Python
win10下opencv-python特定版本手动安装与pip自动安装教程
2020/03/05 Python
法国一家芭蕾舞鞋公司:Repetto
2018/11/12 全球购物
墨尔本最受欢迎的复古风格品牌:Princess Highway
2018/12/21 全球购物
公共汽车、火车和飞机票的通用在线预订和销售平台:INFOBUS
2019/11/30 全球购物
工伤调解协议书
2016/03/21 职场文书