对pytorch中的梯度更新方法详解


Posted in Python onAugust 20, 2019

背景

使用pytorch时,有一个yolov3的bug,我认为涉及到学习率的调整。收集到tencent yolov3和mxnet开源的yolov3,两个优化器中的学习率设置不一样,而且使用GPU数目和batch的更新也不太一样。据此,我简单的了解了下pytorch的权重梯度的更新策略,看看能否一窥究竟。

对代码说明

共三个实验,分布写在代码中的(一)(二)(三)三个地方。运行实验时注释掉其他两个

实验及其结果

实验(三):

不使用zero_grad()时,grad累加在一起,官网是使用accumulate 来表述的,所以不太清楚是取的和还是均值(这两种最有可能)。

不使用zero_grad()时,是直接叠加add的方式累加的。

tensor([[[ 1., 1.],……torch.Size([2, 2, 2])
0 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
tensor([[[ 2., 2.],…… torch.Size([2, 2, 2])
1 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
2 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

实验(二):

单卡上不同的batchsize对梯度是怎么作用的。 mini-batch SGD中的batch是加快训练,同时保持一定的噪声。但设置不同的batchsize的权重的梯度是怎么计算的呢。

设置运行实验(二),可以看到结果如下:所以单卡batchsize计算梯度是取均值的

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

实验(一):

多gpu情况下,梯度怎么合并在一起的。

在《training imagenet in 1 hours》中提到grad是allreduce的,是累加的形式。但是当设置g=2,实验一运行时,结果也是取均值的,类同于实验(二)

tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])

实验代码

import torch
import torch.nn as nn
from torch.autograd import Variable


class model(nn.Module):
 def __init__(self, w):
  super(model, self).__init__()
  self.w = w

 def forward(self, xx):
  b, c, _, _ = xx.shape
  # extra = xx.device.index + 1 ## 实验(一)
  y = xx.reshape(b, -1).mm(self.w.cuda(xx.device).reshape(-1, 2) * extra)
  return y.reshape(len(xx), -1)


g = 1
x = Variable(torch.ones(2, 1, 2, 2))
# x[1] += 1 ## 实验(二)
w = Variable(torch.ones(2, 2, 2) * 2, requires_grad=True)
# optim = torch.optim.SGD({'params': x},
lr = 0.01
momentum = 0.9
M = model(w)

M = torch.nn.DataParallel(M, device_ids=range(g))

for i in range(3):
 b = len(x)
 z = M(x)
 zz = z.sum(1)
 l = (zz - Variable(torch.ones(b).cuda())).mean()
 # zz.backward(Variable(torch.ones(b).cuda()))
 l.backward()
 print(w.grad, w.grad.shape)
 # w.grad.zero_() ## 实验(三)
 print(i, b, '* * ' * 20)

以上这篇对pytorch中的梯度更新方法详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中for循环和while循环的基本使用方法
Aug 21 Python
Python中的os.path路径模块中的操作方法总结
Jul 07 Python
python爬虫获取京东手机图片的图文教程
Dec 29 Python
Python SQLite3简介
Feb 22 Python
浅谈python中字典append 到list 后值的改变问题
May 04 Python
使用Django启动命令行及执行脚本的方法
May 29 Python
示例详解Python3 or Python2 两者之间的差异
Aug 23 Python
Python sorted函数详解(高级篇)
Sep 18 Python
Python global全局变量函数详解
Sep 18 Python
python绘制BA无标度网络示例代码
Nov 21 Python
深度学习入门之Pytorch 数据增强的实现
Feb 26 Python
Python Opencv 通过轨迹(跟踪)栏实现更改整张图像的背景颜色
Mar 09 Python
PyTorch: 梯度下降及反向传播的实例详解
Aug 20 #Python
python爬虫 urllib模块发起post请求过程解析
Aug 20 #Python
pytorch 加载(.pth)格式的模型实例
Aug 20 #Python
python multiprocessing模块用法及原理介绍
Aug 20 #Python
python 并发编程 阻塞IO模型原理解析
Aug 20 #Python
PyTorch中常用的激活函数的方法示例
Aug 20 #Python
Pytorch抽取网络层的Feature Map(Vgg)实例
Aug 20 #Python
You might like
PHP如何得到当前页和上一页的地址?
2006/11/27 PHP
php下MYSQL limit的优化
2008/01/10 PHP
PHP curl模拟浏览器采集阿里巴巴的实现代码
2011/04/20 PHP
浅谈web上存漏洞及原理分析、防范方法(文件名检测漏洞)
2013/06/29 PHP
原生js和jQuery随意改变div属性style的名称和值
2014/10/22 Javascript
详解javascript实现瀑布流绝对式布局
2016/01/29 Javascript
javascript基本语法
2016/05/31 Javascript
vue mounted组件的使用
2018/06/18 Javascript
Vuejs 实现简易 todoList 功能 与 组件实例代码
2018/09/10 Javascript
Layui table field初始化加载时进行隐藏的方法
2019/09/19 Javascript
vue中使用极验验证码的方法(附demo)
2019/12/04 Javascript
原生javascript的ajax请求及后台PHP响应操作示例
2020/02/24 Javascript
原生JS实现多条件筛选
2020/08/19 Javascript
详解vue-cli项目在IE浏览器打开报错解决方法
2020/12/10 Vue.js
[02:17]DOTA2亚洲邀请赛 RAVE战队出场宣传片
2015/02/07 DOTA
Python列表append和+的区别浅析
2015/02/02 Python
在Python中使用pngquant压缩png图片的教程
2015/04/09 Python
Python输出汉字字库及将文字转换为图片的方法
2016/06/04 Python
用pickle存储Python的原生对象方法
2017/04/28 Python
kaggle+mnist实现手写字体识别
2018/07/26 Python
Python自定义一个类实现字典dict功能的方法
2019/01/19 Python
解决webdriver.Chrome()报错:Message:'chromedriver' executable needs to be in Path
2019/06/12 Python
Python图片的横坐标汉字实例
2019/12/04 Python
Ranorex通过Python将报告发送到邮箱的方法
2020/01/12 Python
浅谈pandas.cut与pandas.qcut的使用方法及区别
2020/03/03 Python
Python3基于plotly模块保存图片表格
2020/08/03 Python
html5将图片转换成base64的实例代码
2016/09/21 HTML / CSS
小米俄罗斯授权商店:Xiaomi俄罗斯
2019/12/08 全球购物
哈萨克斯坦移动和数字技术在线商店:SatelOnline.kz
2020/09/04 全球购物
销售经理助理岗位职责
2015/04/13 职场文书
保留意见审计报告
2015/06/05 职场文书
三八妇女节新闻稿
2015/07/17 职场文书
django中websocket的具体使用
2022/01/22 Python
Python实现日志实时监测的示例详解
2022/04/06 Python
Java 使用类型为Object的变量指向任意类型的对象
2022/04/13 Java/Android
关于MySQL中explain工具的使用
2023/05/08 MySQL