关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
让python的Cookie.py模块支持冒号做key的方法
Dec 28 Python
Python UnicodeEncodeError: 'gbk' codec can't encode character 解决方法
Apr 24 Python
Phantomjs抓取渲染JS后的网页(Python代码)
May 13 Python
matplotlib简介,安装和简单实例代码
Dec 26 Python
python机器学习理论与实战(六)支持向量机
Jan 19 Python
Django2.1.3 中间件使用详解
Nov 26 Python
pycharm运行程序时在Python console窗口中运行的方法
Dec 03 Python
对python数据切割归并算法的实例讲解
Dec 12 Python
PyCharm 设置SciView工具窗口的方法
Jan 15 Python
python实现两个文件夹的同步
Aug 29 Python
django admin 根据choice字段选择的不同来显示不同的页面方式
May 13 Python
基于PyTorch实现一个简单的CNN图像分类器
May 29 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
PHP clearstatcache()函数详解
2010/03/02 PHP
linux iconv方法的使用
2011/10/01 PHP
使用php验证复选框有效性的示例
2013/11/13 PHP
Thinkphp中Create方法深入探究
2014/06/16 PHP
php实现在多维数组中查找特定value的方法
2015/07/29 PHP
php 问卷调查结果统计
2015/10/08 PHP
JavaScript格式化数字的函数代码
2010/11/30 Javascript
20款超赞的jQuery插件 Web开发人员必备
2011/02/26 Javascript
jquery动态更换设置背景图的方法
2014/03/25 Javascript
easyUI下拉列表点击事件使用方法
2017/05/18 Javascript
JS实现自动轮播图效果(自适应屏幕宽度+手机触屏滑动)
2017/06/19 Javascript
详解vue axios中文文档
2017/09/12 Javascript
浅谈针对Vue相同路由不同参数的刷新问题
2018/09/29 Javascript
Angular使用Restful的增删改
2018/12/28 Javascript
js刷新页面location.reload()用法详解
2019/12/09 Javascript
基于vue+echarts 数据可视化大屏展示的方法示例
2020/03/09 Javascript
[03:17]2014DOTA2 国际邀请赛中国区预选赛 四强专访
2014/05/23 DOTA
[48:46]完美世界DOTA2联赛PWL S2 SZ vs FTD.C 第二场 11.19
2020/11/19 DOTA
python实现图片识别汽车功能
2018/11/30 Python
Python面向对象程序设计类的封装与继承用法示例
2019/04/12 Python
python高斯分布概率密度函数的使用详解
2019/07/10 Python
Python 使用 PyMysql、DBUtils 创建连接池提升性能
2019/08/14 Python
djano一对一、多对多、分页实例代码
2019/08/16 Python
无需压缩软件,用python帮你操作压缩包
2020/08/17 Python
python如何用matplotlib创建三维图表
2021/01/26 Python
CSS3制作文字半透明倒影效果的两种实现方式
2014/08/08 HTML / CSS
微软香港官网及网上商店:Microsoft HK
2016/09/01 全球购物
英国时尚和家居用品零售商:Matalan
2021/02/28 全球购物
大专学生求职自荐信
2014/07/06 职场文书
大一新生期末自我评价
2014/09/12 职场文书
教育实践活动对照检查材料
2014/09/23 职场文书
有限责任公司股东合作协议书
2014/12/02 职场文书
中学社团活动总结
2015/05/07 职场文书
新郎接新娘保证书
2015/05/08 职场文书
运动会加油稿30字
2015/07/21 职场文书
2016学习全国教书育人楷模先进事迹心得体会
2016/01/21 职场文书