Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python执行子进程实现进程间通信的方法
Jun 02 Python
详解python中字典的循环遍历的两种方式
Feb 07 Python
使用Python对Csv文件操作实例代码
May 12 Python
Python编程实现粒子群算法(PSO)详解
Nov 13 Python
python使用suds调用webservice接口的方法
Jan 03 Python
如何在Django中添加没有微秒的 DateTimeField 属性详解
Jan 30 Python
Python turtle库绘制菱形的3种方式小结
Nov 23 Python
Django多进程滚动日志问题解决方案
Dec 17 Python
pytorch中torch.max和Tensor.view函数用法详解
Jan 03 Python
如何使用flask将模型部署为服务
May 13 Python
Jupyter notebook 更改文件打开的默认路径操作
May 21 Python
python turtle绘图命令及案例
Nov 23 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
德劲1103二次变频版的打磨
2021/03/02 无线电
留言板翻页的实现详解
2006/10/09 PHP
PHP中使用curl伪造IP的简单方法
2015/08/07 PHP
php计划任务之验证是否有多个进程调用同一个job的方法
2015/12/07 PHP
实例:用 JavaScript 来操作字符串(一些字符串函数)
2007/02/15 Javascript
jquery+json实现的搜索加分页效果
2010/03/31 Javascript
jquery ajax 同步异步的执行示例代码
2010/06/23 Javascript
xss文件页面内容读取(解决)
2010/11/28 Javascript
jquery 获取标签名(tagName)示例代码
2013/07/11 Javascript
javascript与有限状态机详解
2014/05/08 Javascript
js遍历子节点子元素附属性及方法
2014/08/19 Javascript
浅谈javascript中replace()方法
2015/11/10 Javascript
JS基于MSClass和setInterval实现ajax定时采集信息并滚动显示的方法
2016/04/18 Javascript
实用jquery操作表单元素的简单代码
2016/07/04 Javascript
JQuery 动态生成Table表格实例代码
2016/12/02 Javascript
实例分析浏览器中“JavaScript解析器”的工作原理
2016/12/12 Javascript
详解React-Native解决键盘遮挡问题(Keyboard遮挡问题)
2017/07/13 Javascript
Scala解析Json字符串的实例详解
2017/10/11 Javascript
php 解压zip压缩包内容到指定目录的实例
2018/01/23 Javascript
AngularJS 应用模块化的使用
2018/04/04 Javascript
Node.js 使用jade模板引擎的示例
2018/05/11 Javascript
jQuery中$原理实例分析
2018/08/13 jQuery
vue基础之模板和过滤器用法实例分析
2019/03/12 Javascript
javascript实现小型区块链功能
2019/04/03 Javascript
vue实现选中效果
2020/10/07 Javascript
详解Python核心对象类型字符串
2018/02/11 Python
html5的canvas方法使用指南
2014/12/15 HTML / CSS
全面解析HTML5中的标准属性与自定义属性
2016/02/18 HTML / CSS
BNKR中国官网:带你感受澳洲领先潮流时尚
2018/08/21 全球购物
《自然之道》教学反思
2014/02/11 职场文书
关于运动会广播稿300字
2014/10/05 职场文书
2014年食品安全工作总结
2014/12/04 职场文书
优秀教师个人总结
2015/02/11 职场文书
2015新教师教学工作总结
2015/07/22 职场文书
读《推着妈妈去旅行》有感1500字
2019/10/15 职场文书
JavaScript正则表达式实现注册信息校验功能
2022/05/30 Java/Android