Pytorch反向传播中的细节-计算梯度时的默认累加操作


Posted in Python onJune 05, 2021

Pytorch反向传播计算梯度默认累加

今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:

pytorch实现线性回归

先附上试验代码来感受一下:

torch.manual_seed(6)
lr = 0.01   # 学习率
result = []

# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) 

# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)
# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次
for iteration in range(2):

    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 反向传播
    loss.backward()
    
    # 这里看一下反向传播计算的梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-122.9075])
b.grad: tensor([-20.9364])

然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)
for iteration in range(2):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
     if iteration > 0: 
        w.grad.data.zero_()
        b.grad.data.zero_()
    
    # 反向传播
    loss.backward()
    
    # 看一下梯度
    print("w.grad:", w.grad)
    print("b.grad:", b.grad)
    
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

w.grad: tensor([-74.6261])
b.grad: tensor([-12.5532])
w.grad: tensor([-48.2813])
b.grad: tensor([-8.3831])

从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零。

但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:

torch.manual_seed(6)
lr = 0.01
result = []
# 创建训练数据
x = torch.rand(20, 1) * 10
#print(x)
y = 2 * x + (5 + torch.randn(20, 1)) 
#print(y)
# 构建线性回归函数
w = torch.randn((1), requires_grad=True)
#print(w)
b = torch.zeros((1), requires_grad=True)
#print(b)

for iteration in range(1000):
    # 前向传播
    wx = torch.mul(w, x)
    y_pred = torch.add(wx, b)

    # 计算 MSE loss
    loss = (0.5 * (y - y_pred) ** 2).mean()
#     print("iteration {}: loss {}".format(iteration, loss))
    result.append(loss)
    
    # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0
    #if iteration > 0: 
    #    w.grad.data.zero_()
    #    b.grad.data.zero_()
  
    # 反向传播
    loss.backward()
 
    # 更新参数
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)
    
    if loss.data.numpy() < 1:
        break
   plt.plot(result)

上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化。

总结

这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:

Pytorch反向传播中的细节-计算梯度时的默认累加操作

但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可

w.grad.data.zero_()
b.grad.data.zero_()

另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:

Pytorch反向传播中的细节-计算梯度时的默认累加操作
Pytorch反向传播中的细节-计算梯度时的默认累加操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置函数bin() oct()等实现进制转换
Dec 30 Python
Win10下Python环境搭建与配置教程
Nov 18 Python
Python实现随机生成有效手机号码及身份证功能示例
Jun 05 Python
详解Python中的分组函数groupby和itertools)
Jul 11 Python
python 移除字符串尾部的数字方法
Jul 17 Python
python3.6使用urllib完成下载的实例
Dec 19 Python
对Python强大的可变参数传递机制详解
Jun 13 Python
python3常用的数据清洗方法(小结)
Oct 31 Python
Python生成词云的实现代码
Jan 14 Python
python_array[0][0]与array[0,0]的区别详解
Feb 18 Python
使用SQLAlchemy操作数据库表过程解析
Jun 10 Python
Python下使用Trackbar实现绘图板
Oct 27 Python
pytorch 梯度NAN异常值的解决方案
Jun 05 #Python
pytorch 权重weight 与 梯度grad 可视化操作
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
You might like
php下清空字符串中的HTML标签的代码
2010/09/06 PHP
浅析PHP原理之变量分离/引用(Variables Separation)
2013/08/09 PHP
php使用strtotime和date函数判断日期是否有效代码分享
2013/12/25 PHP
推荐十款免费 WordPress 插件
2015/03/24 PHP
使用phpQuery获取数组的实例
2017/03/13 PHP
PHP构造二叉树算法示例
2017/06/21 PHP
PHP Swoole异步MySQL客户端实现方法示例
2019/10/24 PHP
PHP unset函数原理及使用方法解析
2020/08/14 PHP
用jquery来定位
2007/02/20 Javascript
clientX,pageX,offsetX,x,layerX,screenX,offsetLeft区别分析
2010/03/12 Javascript
Js+Flash实现访问剪切板操作
2012/11/20 Javascript
js jquery获取当前元素的兄弟级 上一个 下一个元素
2015/09/01 Javascript
jQuery插件实现静态HTML验证码校验
2015/11/06 Javascript
jQuery自定义滚动条完整实例
2016/01/08 Javascript
js实现加载更多功能实例
2016/10/27 Javascript
Servlet实现文件上传,可多文件上传示例
2016/12/05 Javascript
JS动态修改网页body的背景色实例代码
2017/10/07 Javascript
jQuery实现监听下拉框选中内容发生改变操作示例
2018/07/13 jQuery
微信小程序使用map组件实现检索(定位位置)周边的POI功能示例
2019/01/23 Javascript
在Vue.js中使用TypeScript的方法
2020/03/19 Javascript
Vue向后台传数组数据,springboot接收vue传的数组数据实例
2020/11/12 Javascript
python获得linux下所有挂载点(mount points)的方法
2015/04/29 Python
Python实现ssh批量登录并执行命令
2016/10/25 Python
解决python多行注释引发缩进错误的问题
2019/08/23 Python
详解python中的模块及包导入
2019/08/30 Python
详解Django将秒转换为xx天xx时xx分
2019/09/27 Python
python print 格式化输出,动态指定长度的实现
2020/04/12 Python
Python基于tkinter canvas实现图片裁剪功能
2020/11/05 Python
Scrapy-Redis之RedisSpider与RedisCrawlSpider详解
2020/11/18 Python
python 实现有道翻译功能
2021/02/26 Python
成人继续教育实施方案
2014/03/01 职场文书
优秀班集体先进事迹材料
2014/05/28 职场文书
物流管理专业求职信
2014/05/29 职场文书
小学生勤俭节约演讲稿
2014/08/28 职场文书
国家税务局干部作风整顿整改措施
2014/09/18 职场文书
祝福语集锦:送给闺蜜的生日祝福语
2019/10/08 职场文书