PyTorch的Optimizer训练工具的实现


Posted in Python onAugust 18, 2019

torch.optim 是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。

使用 torch.optim,必须构造一个 optimizer 对象。这个对象能保存当前的参数状态并且基于计算梯度更新参数。

例如:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr = 0.0001)

构造方法

Optimizer 的 __init__ 函数接收两个参数:第一个是需要被优化的参数,其形式必须是 Tensor 或者 dict;第二个是优化选项,包括学习率、衰减率等。

被优化的参数一般是 model.parameters(),当有特殊需求时可以手动写一个 dict 来作为输入。

例如:

optim.SGD([
  {'params': model.base.parameters()},
  {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)

这样 model.base 或者说大部分的参数使用 1e-2 的学习率,而 model.classifier 的参数使用 1e-3 的学习率,并且 0.9 的 momentum 被用于所有的参数。

梯度控制

在进行反向传播之前,必须要用 zero_grad() 清空梯度。具体的方法是遍历 self.param_groups 中全部参数,根据 grad 属性做清除。

例如:

for input, target in dataset:
  def closure():
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    return loss
  optimizer.step(closure)

调整学习率

lr_scheduler 用于在训练过程中根据轮次灵活调控学习率。调整学习率的方法有很多种,但是其使用方法是大致相同的:用一个 Schedule 把原始 Optimizer 装饰上,然后再输入一些相关参数,然后用这个 Schedule 做 step()。

比如以 LambdaLR 举例:

lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
for epoch in range(100):
 train(...)
 validate(...)
 scheduler.step()

上面用了两种优化器

优化方法

optim 库中实现的算法包括 Adadelta、Adagrad、Adam、基于离散张量的 Adam、基于 ∞ \infty∞ 范式的 Adam(Adamax)、Averaged SGD、L-BFGS、RMSProp、resilient BP、基于 Nesterov 的 SGD 算法。

以 SGD 举例:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()

其它方法的使用也一样:

opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=0.1, betas=(0.9, 0.99)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=0.1, alpha=0.9)
...
...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的print用法示例
Feb 11 Python
Python实现多线程抓取妹子图
Aug 08 Python
浅析python实现scrapy定时执行爬虫
Mar 04 Python
Python将DataFrame的某一列作为index的方法
Apr 08 Python
Python使用numpy模块创建数组操作示例
Jun 20 Python
在Python中增加和插入元素的示例
Nov 01 Python
Pandas-Cookbook 时间戳处理方式
Dec 07 Python
python实现梯度下降和逻辑回归
Mar 24 Python
简单了解django处理跨域请求最佳解决方案
Mar 25 Python
Jupyter notebook 远程配置及SSL加密教程
Apr 14 Python
使用tensorflow根据输入更改tensor shape
Jun 23 Python
Python生成九宫格图片的示例代码
Apr 14 Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
You might like
简单了解WordPress开发中update_option()函数的用法
2016/01/11 PHP
简单谈谈php浮点数精确运算
2016/03/10 PHP
Yii中的cookie的发送和读取
2016/07/27 PHP
Thinkphp框架中D方法与M方法的区别
2016/12/23 PHP
Laravel框架查询构造器简单示例
2019/05/08 PHP
Using the TextRange Object
2006/10/14 Javascript
JS类定义原型方法的两种实现的区别评论很多
2007/09/12 Javascript
JS是否可以跨文件同时控制多个iframe页面的应用技巧
2007/12/16 Javascript
IE与Firefox下javascript getyear年份的兼容性写法
2007/12/20 Javascript
JavaScript 一行代码,轻松搞定浮动快捷留言-V2升级版
2010/04/02 Javascript
超简单的jquery的AJAX用法
2010/05/10 Javascript
复制Input内容的js代码_支持所有浏览器,修正了Firefox3.5以上的问题
2010/06/21 Javascript
js两行代码按指定格式输出日期时间
2011/10/21 Javascript
seaJs的模块定义和模块加载浅析
2014/06/06 Javascript
javascript中eval函数用法分析
2015/04/25 Javascript
JavaScript关于提高网站性能的几点建议(一)
2016/07/24 Javascript
JavaScript之cookie技术详解
2016/11/18 Javascript
jQuery实现动态文字搜索功能
2017/01/05 Javascript
折叠菜单及选择器的运用
2017/02/03 Javascript
搭建简单的nodejs http服务器详解
2017/03/09 NodeJs
基于layPage插件实现两种分页方式浅析
2019/07/27 Javascript
python合并文本文件示例
2014/02/07 Python
Python Paramiko模块的安装与使用详解
2016/11/18 Python
Python异常处理知识点总结
2019/02/18 Python
PyCharm中代码字体大小调整方法
2019/07/29 Python
python爬虫 基于requests模块的get请求实现详解
2019/08/20 Python
Anaconda3中的Jupyter notebook添加目录插件的实现
2020/05/18 Python
Elasticsearch py客户端库安装及使用方法解析
2020/09/14 Python
CSS3绘制超炫的上下起伏波动进度加载动画
2016/04/21 HTML / CSS
欧洲领先的电子和电信零售商和服务提供商:Currys PC World Business
2017/12/05 全球购物
伦敦眼门票在线预订:London Eye
2018/05/31 全球购物
抽象方法、抽象类怎样声明
2014/10/25 面试题
统计学教授推荐信
2014/09/18 职场文书
国博复兴之路观后感
2015/06/02 职场文书
python实现过滤敏感词
2021/05/08 Python
pytorch 中autograd.grad()函数的用法说明
2021/05/12 Python