关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现数组插入新元素的方法
May 22 Python
Linux下将Python的Django项目部署到Apache服务器
Dec 24 Python
几种实用的pythonic语法实例代码
Feb 24 Python
python实现简单登陆流程的方法
Apr 22 Python
详解Django+Uwsgi+Nginx的生产环境部署
Jun 25 Python
OpenCV+python手势识别框架和实例讲解
Aug 03 Python
pyqt5对用qt designer设计的窗体实现弹出子窗口的示例
Jun 19 Python
Pytorch根据layers的name冻结训练方式
Jan 06 Python
如何在python开发工具PyCharm中搭建QtPy环境(教程详解)
Feb 04 Python
Python中的整除和取模实例
Jun 03 Python
Python常用类型转换实现代码实例
Jul 28 Python
python如何利用cv2模块读取显示保存图片
Jun 04 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
phpmyadmin MySQL 加密配置方法
2009/07/05 PHP
PHP文件上传主要代码讲解
2013/09/30 PHP
PHP 正则判断中文UTF-8或GBK的思路及具体实现
2013/11/26 PHP
IE JS编程需注意的内存释放问题
2009/06/23 Javascript
javascript 操作cookies及正确使用cookies的属性
2009/10/15 Javascript
javascript Array数组对象的扩展函数代码
2010/05/22 Javascript
jquery弹出关闭遮罩层实例
2013/08/06 Javascript
JavaScript插件化开发教程 (二)
2015/01/27 Javascript
javascript基本包装类型介绍
2015/04/10 Javascript
每天一篇javascript学习小结(RegExp对象)
2015/11/17 Javascript
让DIV的滚动条自动滚动到最底部的3种方法(推荐)
2016/09/24 Javascript
Node.js编写CLI的实例详解
2017/05/17 Javascript
使用cropper.js裁剪头像的实例代码
2017/09/29 Javascript
详解webpack打包第三方类库的正确姿势
2018/10/20 Javascript
读懂CommonJS的模块加载
2019/04/19 Javascript
layui实现数据表格table分页功能(ajax异步)
2019/07/27 Javascript
JS实现“全选”和&quot;全不选&quot;功能代码实例
2020/02/06 Javascript
JavaScript接口实现方法实例分析
2020/05/16 Javascript
vant-ui框架的一个bug(解决切换后onload不触发)
2020/11/11 Javascript
vue+Element-ui实现分页效果
2020/11/15 Javascript
[58:32]EG vs Liquid 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
python 中文乱码问题深入分析
2011/03/13 Python
Python使用pymysql小技巧
2017/06/04 Python
python 中random模块的常用方法总结
2017/07/08 Python
浅谈python连续赋值可能引发的错误
2018/11/10 Python
如何导出python安装的所有模块名称和版本号到文件中
2020/06/05 Python
如何在Windows中安装多个python解释器
2020/06/16 Python
python中slice参数过长的处理方法及实例
2020/12/15 Python
从零实现一个自定义html5播放器的示例代码
2017/08/01 HTML / CSS
存储过程的优缺点是什么
2015/01/10 面试题
英语生日邀请函
2014/01/23 职场文书
自我鉴定总结
2014/03/24 职场文书
幼儿园中班教师个人总结
2015/02/05 职场文书
小学主题班会教案
2015/08/17 职场文书
利用For循环遍历Python字典的三种方法实例
2022/03/25 Python
Jmerte 分布式压测及分布式压测配置
2022/04/30 Java/Android