pytorch中Schedule与warmup_steps的用法说明


Posted in Python onMay 24, 2021

1. lr_scheduler相关

lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)

其中args.warmup_steps可以认为是耐心系数

num_train_optimization_steps为模型参数的总更新次数

一般来说:

num_train_optimization_steps = int(total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)

Schedule用来调节学习率,拿线性变换调整来说,下面代码中,step是当前迭代次数。

def lr_lambda(self, step):
        # 线性变换,返回的是某个数值x,然后返回到类LambdaLR中,最终返回old_lr*x
        if step < self.warmup_steps: # 增大学习率
            return float(step) / float(max(1, self.warmup_steps))
        # 减小学习率
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

在实际运行中,lr_scheduler.step()先将lr初始化为0. 在第一次参数更新时,此时step=1,lr由0变为初始值initial_lr;在第二次更新时,step=2,上面代码中生成某个实数alpha,新的lr=initial_lr *alpha;在第三次更新时,新的lr是在initial_lr基础上生成,即新的lr=initial_lr *alpha。

其中warmup_steps可以认为是lr调整的耐心系数。

由于有warmup_steps存在,lr先慢慢增加,超过warmup_steps时,lr再慢慢减小。

在实际中,由于训练刚开始时,训练数据计算出的grad可能与期望方向相反,所以此时采用较小的lr,随着迭代次数增加,lr线性增大,增长率为1/warmup_steps;迭代次数等于warmup_steps时,学习率为初始设定的学习率;迭代次数超过warmup_steps时,学习率逐步衰减,衰减率为1/(total-warmup_steps),再进行微调。

2. gradient_accumulation_steps相关

gradient_accumulation_steps通过累计梯度来解决本地显存不足问题。

假设原来的batch_size=6,样本总量为24,gradient_accumulation_steps=2

那么参数更新次数=24/6=4

现在,减小batch_size=6/2=3,参数更新次数不变=24/3/2=4

在梯度反传时,每gradient_accumulation_steps次进行一次梯度更新,之前照常利用loss.backward()计算梯度。

补充:pytorch学习笔记 -optimizer.step()和scheduler.step()

optimizer.step()和scheduler.step()的区别

optimizer.step()通常用在每个mini-batch之中,而scheduler.step()通常用在epoch里面,但是不绝对,可以根据具体的需求来做。只有用了optimizer.step(),模型才会更新,而scheduler.step()是对lr进行调整。

通常我们有

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)

在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。

所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中映射类型(字典)操作符的概念和使用
Aug 19 Python
Sanic框架Cookies操作示例
Jul 17 Python
Python二进制串转换为通用字符串的方法
Jul 23 Python
Python拼接微信好友头像大图的实现方法
Aug 01 Python
centos6.8安装python3.7无法import _ssl的解决方法
Sep 17 Python
Django中使用Celery的方法示例
Nov 29 Python
对YOLOv3模型调用时候的python接口详解
Aug 26 Python
python编写简单端口扫描器
Sep 04 Python
Django 路由层URLconf的实现
Dec 30 Python
Python更新所有已安装包的操作
Feb 13 Python
Python sublime安装及配置过程详解
Jun 29 Python
Python使用Kubernetes API访问集群
May 30 Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
You might like
php中$_GET与$_POST过滤sql注入的方法
2014/11/03 PHP
php获取英文姓名首字母的方法
2015/07/13 PHP
完美利用Yii2微信后台开发的系列总结
2016/07/18 PHP
php使用curl实现ftp文件下载功能
2017/05/16 PHP
PHP设计模式之工厂模式(Factory Pattern)的讲解
2019/03/21 PHP
判断是否输入完毕再激活提交按钮
2006/06/26 Javascript
json简单介绍
2008/06/10 Javascript
JavaScript面向对象设计二 构造函数模式
2011/12/20 Javascript
JS调用CS里的带参方法实例
2013/08/01 Javascript
js获取当前页面的url网址信息
2014/06/12 Javascript
js实现简单选项卡与自动切换效果的方法
2015/04/10 Javascript
JQuery radio(单选按钮)操作方法汇总
2015/04/15 Javascript
JQuery显示隐藏DIV的方法及代码实例
2015/04/16 Javascript
jQuery的基本概念与高级编程
2015/05/14 Javascript
JavaScript中利用各种循环进行遍历的方式总结
2015/11/10 Javascript
常用js,css文件统一加载方法(推荐) 并在加载之后调用回调函数
2016/09/23 Javascript
AngularJS入门教程之数据绑定用法示例
2016/11/01 Javascript
js 递归和定时器的实例解析
2017/02/03 Javascript
详解node如何让一个端口同时支持https与http
2017/07/04 Javascript
详解vue项目首页加载速度优化
2017/10/18 Javascript
element ui里dialog关闭后清除验证条件方法
2018/02/26 Javascript
Angular5整合富文本编辑器TinyMCE的方法(汉化+上传)
2020/05/26 Javascript
js实现Element中input组件的部分功能并封装成组件(实例代码)
2021/03/02 Javascript
利用Python脚本在Nginx和uwsgi上部署MoinMoin的教程
2015/05/05 Python
Python的Flask框架中的Jinja2模板引擎学习教程
2016/06/30 Python
Python中str.format()详解
2017/03/12 Python
一个可以套路别人的python小程序实例代码
2019/04/09 Python
详解如何设置Python环境变量?
2019/05/13 Python
Python手绘可视化工具cutecharts使用实例
2019/12/05 Python
Python 3.8 新功能来一波(大部分人都不知道)
2020/03/11 Python
Python 处理日期时间的Arrow库使用
2020/08/18 Python
python基于爬虫+django,打造个性化API接口
2021/01/21 Python
中学生英语演讲稿
2014/04/26 职场文书
如何将JavaScript将数组转为树形结构
2021/06/02 Javascript
Windows下redis下载、redis安装及使用教程
2021/06/02 Redis
Nginx工作模式及代理配置的使用细节
2022/03/21 Servers