关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在win和Linux系统中python命令行运行的不同
Jul 03 Python
python面向对象_详谈类的继承与方法的重载
Jun 07 Python
Python实现pdf文档转txt的方法示例
Jan 19 Python
VScode编写第一个Python程序HelloWorld步骤
Apr 06 Python
python实现音乐下载器
Apr 15 Python
对PyTorch torch.stack的实例讲解
Jul 30 Python
Python pymongo模块常用操作分析
Sep 01 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 Python
聊聊python里如何用Borg pattern实现的单例模式
Jun 06 Python
如何在Cloud Studio上执行Python代码?
Aug 09 Python
jupyter notebook 增加kernel教程
Apr 10 Python
BeautifulSoup中find和find_all的使用详解
Dec 07 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
php实现的发送带附件邮件类实例
2014/09/22 PHP
PHP实现路由映射到指定控制器
2016/08/13 PHP
thinkphp3.2中实现phpexcel导出带生成图片示例
2017/02/14 PHP
PHP autoload使用方法及步骤详解
2020/09/05 PHP
用prototype实现的简单小巧的多级联动菜单
2007/03/24 Javascript
Jquery 选中表格一列并对表格排序实现原理
2012/12/15 Javascript
js同比例缩放图片的小例子
2013/10/30 Javascript
Node.js文件操作详解
2014/08/16 Javascript
JavaScript模拟可展开、拖动与关闭的聊天窗口实例
2015/05/12 Javascript
js控制网页前进和后退的方法
2015/06/08 Javascript
微信小程序 Toast自定义实例详解
2017/01/20 Javascript
基于JavaScript实现瀑布流效果
2017/03/29 Javascript
js导出Excel表格超出26位英文字符的解决方法ES6
2017/11/15 Javascript
vue异步axios获取的数据渲染到页面的方法
2018/08/09 Javascript
解决vue初始化项目一直停在downloading template的问题
2020/11/09 Javascript
vue-cli 3如何使用vue-bootstrap-datetimepicker日期插件
2021/02/20 Vue.js
[41:52]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第二场 2月22日
2021/03/11 DOTA
Python for循环生成列表的实例
2018/06/15 Python
Python flask框架post接口调用示例
2019/07/03 Python
Django实现分页显示效果
2019/10/31 Python
简单了解python字符串前面加r,u的含义
2019/12/26 Python
python实现将列表中各个值快速赋值给多个变量
2020/04/02 Python
Python闭包装饰器使用方法汇总
2020/06/29 Python
美国最大的珠宝商之一:Littman Jewelers
2016/11/13 全球购物
欧洲顶级的童装奢侈品购物网站:Bambini Fashion(面向全球)
2018/04/24 全球购物
皇家阿尔伯特英国官方商店:Royal Albert骨瓷
2019/03/25 全球购物
Seavenger官网:潜水服、浮潜、靴子和袜子
2020/03/05 全球购物
大学生年度自我鉴定
2013/10/31 职场文书
护理学专业推荐信
2013/12/03 职场文书
《藏戏》教学反思
2014/02/11 职场文书
公司中层干部的自我评价分享
2014/03/01 职场文书
新闻学专业职业生涯规划范文:我的人生我做主
2014/09/12 职场文书
防暑降温通知书
2015/04/27 职场文书
JDBC连接的六步实例代码(与mysql连接)
2021/05/12 MySQL
Python实现猜拳与猜数字游戏的方法详解
2022/04/06 Python