关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
初学python数组的处理代码
Jan 04 Python
python 实现文件的递归拷贝实现代码
Aug 02 Python
初步解析Python中的yield函数的用法
Apr 03 Python
在Django中创建第一个静态视图
Jul 15 Python
深入讲解Python函数中参数的使用及默认参数的陷阱
Mar 13 Python
使用pycharm生成代码模板的实例
May 23 Python
python读取LMDB中图像的方法
Jul 02 Python
Python 实现文件打包、上传与校验的方法
Feb 13 Python
python web框架中实现原生分页
Sep 08 Python
详解使用python3.7配置开发钉钉群自定义机器人(2020年新版攻略)
Apr 01 Python
python代码实现将列表中重复元素之间的内容全部滤除
May 22 Python
PyTorch中的拷贝与就地操作详解
Dec 09 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
yii框架源码分析之创建controller代码
2011/06/28 PHP
thinkphp,onethink和thinkox中验证码不显示的解决方法分析
2016/06/06 PHP
js left,right,mid函数
2008/06/10 Javascript
强大的jquery插件jqeuryUI做网页对话框效果!简单
2011/04/14 Javascript
JS.getTextContent(element,preformatted)使用介绍
2013/09/21 Javascript
JS批量修改PS中图层名称的方法
2014/01/26 Javascript
js跳转页面方法总结
2014/01/29 Javascript
jquery实现的蓝色二级导航条效果代码
2015/08/24 Javascript
学习Javascript闭包(Closure)知识
2016/08/07 Javascript
微信小程序 animation API详解及实例代码
2016/10/08 Javascript
JQuery实现列表中复选框全选反选功能封装(推荐)
2016/11/24 Javascript
微信开发 JS-SDK 6.0.2 经常遇到问题总结
2016/12/08 Javascript
AngularJS实现使用路由切换视图的方法
2017/01/24 Javascript
AngularJS2 与 D3.js集成实现自定义可视化的方法
2017/12/01 Javascript
微信小程序五子棋游戏的悔棋实现方法【附demo源码下载】
2019/02/20 Javascript
vue cli安装使用less的教程详解
2019/07/12 Javascript
vuex存储复杂参数(如对象数组等)刷新数据丢失的解决方法
2019/11/05 Javascript
vue2.x数组劫持原理的实现
2020/04/19 Javascript
使用React代码动态生成栅格布局的方法
2020/05/24 Javascript
Node.js中出现未捕获异常的处理方法
2020/06/29 Javascript
python实现定时同步本机与北京时间的方法
2015/03/24 Python
Python多线程编程简单介绍
2015/04/13 Python
Python的组合模式与责任链模式编程示例
2016/02/02 Python
python matplotlib画图实例代码分享
2017/12/27 Python
PyQt5事件处理之定时在控件上显示信息的代码
2020/03/25 Python
最新的小工具和卓越的产品设计:Oh That Tech!
2019/08/07 全球购物
泰国第一在线超市:Tops
2021/02/13 全球购物
黄继光的英雄事迹材料
2014/02/13 职场文书
国培计划培训感言
2014/03/11 职场文书
采购求职信
2014/03/17 职场文书
竞聘书模板
2014/03/31 职场文书
2014年城管个人工作总结
2014/12/08 职场文书
银行招聘自荐信
2015/03/06 职场文书
民事起诉书范本
2015/05/19 职场文书
如何用threejs实现实时多边形折射
2021/05/07 Javascript
MySQL实现配置主从复制项目实践
2022/03/31 MySQL