PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python查询Mysql时返回字典结构的代码
Jun 18 Python
python threading模块操作多线程介绍
Apr 08 Python
儿童python练习实例
May 27 Python
Python os.rename() 重命名目录和文件的示例
Oct 25 Python
python+PyQT实现系统桌面时钟
Jun 16 Python
Django+JS 实现点击头像即可更改头像的方法示例
Dec 26 Python
python定时按日期备份MySQL数据并压缩
Apr 19 Python
python虚拟环境完美部署教程
Aug 06 Python
django认证系统实现自定义权限管理的方法
Aug 28 Python
python实现简易淘宝购物
Nov 22 Python
如何基于windows实现python定时爬虫
May 01 Python
在python里使用await关键字来等另外一个协程的实例
May 04 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
深入PHP5中的魔术方法详解
2013/06/17 PHP
ThinkPHP 表单自动验证运用示例
2014/10/13 PHP
[原创]php集成安装包wampserver修改密码后phpmyadmin无法登陆的解决方法
2016/11/23 PHP
Thinkphp框架 表单自动验证登录注册 ajax自动验证登录注册
2016/12/27 PHP
Javascript中的Split使用方法与技巧
2007/03/09 Javascript
JSON 数据格式介绍
2012/01/13 Javascript
jQuery的slideToggle方法实例
2013/05/07 Javascript
jQuery 设置 CSS 属性示例介绍
2014/01/16 Javascript
js实现的标题栏新消息闪烁提示效果
2014/06/06 Javascript
jfinal与bootstrap的登录跳转实战演习
2015/09/22 Javascript
Nodejs Express4.x开发框架随手笔记
2015/11/23 NodeJs
js表单验证实例讲解
2016/03/31 Javascript
利用Angular+Angular-Ui实现分页(代码加简单)
2017/03/10 Javascript
JS按条件 serialize() 对应标签的使用方法
2017/07/24 Javascript
Vue.js项目模板搭建图文教程
2017/09/20 Javascript
nodejs结合Socket.IO实现的即时通讯功能详解
2018/01/12 NodeJs
在Web关闭页面时发送Ajax请求的实现方法
2019/03/07 Javascript
在Python中用get()方法获取字典键值的教程
2015/05/21 Python
Python文件读取的3种方法及路径转义
2015/06/21 Python
Python抓取百度查询结果的方法
2015/07/08 Python
Python实现代码统计工具(终极篇)
2016/07/04 Python
Python如何获取系统iops示例代码
2016/09/06 Python
Python面向对象基础入门之设置对象属性
2018/12/11 Python
分析运行中的 Python 进程详细解析
2019/06/22 Python
在django中实现页面倒数几秒后自动跳转的例子
2019/08/16 Python
使用Python第三方库pygame写个贪吃蛇小游戏
2020/03/06 Python
python sitk.show()与imageJ结合使用常见的问题
2020/04/20 Python
纯css3实现的竖形无限级导航
2014/12/10 HTML / CSS
办公室文员工作职责
2014/01/31 职场文书
人资专员岗位职责
2014/04/04 职场文书
初三学生评语大全
2014/04/24 职场文书
师范大学生求职信
2014/06/13 职场文书
物联网工程专业推荐信
2014/09/08 职场文书
捐书活动倡议书
2015/04/27 职场文书
2019企业给员工的慰问信
2019/06/24 职场文书
JVM钩子函数的使用场景详解
2021/08/23 Java/Android