PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python兔子毒药问题实例分析
Mar 05 Python
Python脚本实现集群检测和管理功能
Mar 06 Python
利用python求相邻数的方法示例
Aug 18 Python
为什么Python中没有"a++"这种写法
Nov 27 Python
Python爬虫实现验证码登录代码实例
May 10 Python
python模拟菜刀反弹shell绕过限制【推荐】
Jun 25 Python
Python中list的交、并、差集获取方法示例
Aug 01 Python
PyQt+socket实现远程操作服务器的方法示例
Aug 22 Python
Keras实现DenseNet结构操作
Jul 06 Python
Django rest framework分页接口实现原理解析
Aug 21 Python
Python数据分析之pandas读取数据
Jun 02 Python
在python中读取和写入CSV文件详情
Jun 28 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
php 分页原理详解
2009/08/21 PHP
php小偷相关截取函数备忘
2010/11/28 PHP
php数据库抽象层 PDO
2011/05/07 PHP
php中模拟POST传递数据的两种方法分享
2011/09/16 PHP
php中sql注入漏洞示例 sql注入漏洞修复
2014/01/24 PHP
php微信开发接入
2016/08/27 PHP
PHP 返回数组后处理方法(开户成功后弹窗提示)
2017/07/03 PHP
jQuery中live方法的重复绑定说明
2011/10/21 Javascript
JavaScript中的变量作用域介绍
2014/12/31 Javascript
Bootstrap实现响应式导航栏效果
2015/12/28 Javascript
jQuery EasyUI之DataGrid使用实例详解
2016/01/04 Javascript
NODE.JS跨域问题的完美解决方案
2016/10/20 Javascript
使用angular帮你实现拖拽的示例
2017/07/05 Javascript
利用JavaScript实现栈的数据结构示例代码
2017/08/02 Javascript
JQuery元素快速查找与操作
2018/04/22 jQuery
微信小程序语音同步智能识别的实现案例代码解析
2020/05/29 Javascript
如何在postman中添加cookie信息步骤解析
2020/06/30 Javascript
Vue Render函数原理及代码实例解析
2020/07/30 Javascript
JavaScript实现多球运动效果
2020/09/07 Javascript
vue 使用 v-model 双向绑定父子组件的值遇见的问题及解决方案
2021/03/01 Vue.js
[04:10]2018年度CS GO玩家最喜爱的主播-完美盛典
2018/12/16 DOTA
在win和Linux系统中python命令行运行的不同
2016/07/03 Python
Pycharm学习教程(5) Python快捷键相关设置
2017/05/03 Python
Python算法之图的遍历
2017/11/16 Python
Laravel框架表单验证格式化输出的方法
2019/09/25 Python
python实现取余操作的简单实例
2020/08/16 Python
python获取时间戳的实现示例(10位和13位)
2020/09/23 Python
美国电视购物:QVC
2017/02/06 全球购物
ALDO英国官网:加拿大女鞋品牌
2018/02/19 全球购物
股权转让协议书
2014/12/07 职场文书
2015年六一儿童节活动总结
2015/02/11 职场文书
公司开业的祝贺语大全(60条)
2019/07/05 职场文书
经典格言警句:没有热忱,世间便无进步
2019/11/13 职场文书
Go中的条件语句Switch示例详解
2021/08/23 Golang
Python中异常处理用法
2021/11/27 Python
Python中使用Opencv开发停车位计数器功能
2022/04/04 Python