PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python urllib模块urlopen()与urlretrieve()详解
Nov 01 Python
python实现监控linux性能及进程消耗性能的方法
Jul 25 Python
理解Python中的With语句
Feb 02 Python
python 换位密码算法的实例详解
Jul 19 Python
解决django前后端分离csrf验证的问题
Feb 03 Python
python3实现字符串操作的实例代码
Apr 16 Python
Django框架模板文件使用及模板文件加载顺序分析
May 23 Python
Python turtle库绘制菱形的3种方式小结
Nov 23 Python
使用Python测试Ping主机IP和某端口是否开放的实例
Dec 17 Python
Python关于__name__属性的含义和作用详解
Feb 19 Python
Python批量删除mysql中千万级大量数据的脚本分享
Dec 03 Python
Biblibili视频投稿接口分析并以Python实现自动投稿功能
Feb 05 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
php xml常用函数的集合(比较详细)
2013/06/06 PHP
Thinkphp5.0自动生成模块及目录的方法详解
2017/04/17 PHP
PHP实现的折半查找算法示例
2017/12/19 PHP
Jquery方式获取iframe页面中的 Dom元素
2014/05/07 Javascript
Javascript基础知识(三)BOM,DOM总结
2014/09/29 Javascript
innerHTML动态添加html代码和脚本兼容多个浏览器
2014/10/11 Javascript
移除AngularJS下URL中的#字符的方法
2015/06/19 Javascript
JavaScript与jQuery实现的闪烁输入效果
2016/02/18 Javascript
EasyUI Pagination 分页的两种做法小结
2016/07/09 Javascript
AngularJs基本特性解析(一)
2016/07/21 Javascript
用js实现简单算法的实例代码
2016/09/24 Javascript
jquery实现下拉框多选方法介绍
2017/01/03 Javascript
JS对象深度克隆实例分析
2017/03/16 Javascript
详解JS函数stack size计算方法
2018/06/18 Javascript
详解微信小程序框架wepy踩坑记录(与vue对比)
2019/03/12 Javascript
原生js实现无缝轮播图
2020/01/11 Javascript
vue中解决拖拽改变存在iframe的div大小时卡顿问题
2020/07/22 Javascript
python中enumerate函数用法实例分析
2015/05/20 Python
Python操作Oracle数据库的简单方法和封装类实例
2018/05/07 Python
python 使用 requests 模块发送http请求 的方法
2018/12/09 Python
Tensorflow累加的实现案例
2020/02/05 Python
python操作docx写入内容,并控制文本的字体颜色
2020/02/13 Python
css3实现顶部社会化分享按钮示例
2014/05/06 HTML / CSS
5分钟弄清楚html5的drag and drop(小结)
2019/04/10 HTML / CSS
HTML5 form标签之解放表单验证、增加文件上传、集成拖放的使用方法
2013/04/24 HTML / CSS
中医药大学市场营销专业自荐信
2013/09/29 职场文书
2014年端午节活动方案
2014/03/11 职场文书
校庆接待方案
2014/03/18 职场文书
祖国在我心中演讲稿400字
2014/05/04 职场文书
新闻专业毕业生求职信
2014/08/08 职场文书
校园广播稿100字
2014/10/06 职场文书
幼儿园班级工作总结2015
2015/05/25 职场文书
房产遗嘱范本
2015/08/06 职场文书
队名及霸气口号大全
2015/12/25 职场文书
《最后一头战象》读后感:动物也有感情
2020/01/02 职场文书
python热力图实现的完整实例
2022/06/25 Python