关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)


Posted in Python onFebruary 20, 2020

torch.optim的灵活使用详解

1. 基本用法:

要构建一个优化器Optimizer,必须给它一个包含参数的迭代器来优化,然后,我们可以指定特定的优化选项,

例如学习速率,重量衰减值等。

注:如果要把model放在GPU中,需要在构建一个Optimizer之前就执行model.cuda(),确保优化器里面的参数也是在GPU中。

例子:

optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum=0.9)

2. 灵活的设置各层的学习率

将model中需要进行BP的层的参数送到torch.optim中,这些层不一定是连续的。

这个时候,Optimizer的参数不是一个可迭代的变量,而是一个可迭代的字典

(字典的key必须包含'params'(查看源码可以得知optimizer通过'params'访问parameters),

其他的key就是optimizer可以接受的,比如说'lr','weight_decay'),可以将这些字典构成一个list,

这样就是一个可迭代的字典了。

注:这个时候,可以在optimizer设置选项作为关键字参数传递,这时它们将被认为是默认值(当字典里面没有这个关键字参数key-value对时,就使用这个默认的参数)

This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.

例子:

optimizer = SGD([
        {'params': model.features12.parameters(), 'lr': 1e-2},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

上面创建的optim.SGD类型的Optimizer,lr默认值为1e-1,momentum默认值为0.9。features12的参数学习率为1e-2。

灵活更改各层的学习率

torch.optim.optimizer.Optimizer的初始化函数如下:

__init__(self, params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)

params (iterable): iterable of parameters to optimize or dicts defining parameter groups (params可以是可迭代的参数,或者一个定义参数组的字典,如上所示,字典的键值包括:params,lr,momentum,dampening,weight_decay,nesterov)

想要改变各层的学习率,可以访问optimizer的param_groups属性。type(optimizer.param_groups) -> list

optimizer.param_groups[0].keys()
Out[21]: ['dampening', 'nesterov', 'params', 'lr', 'weight_decay', 'momentum']

因此,想要更改某层参数的学习率,可以访问optimizer.param_groups,指定某个索引更改'lr'参数就可以。

def adjust_learning_rate(optimizer, decay_rate=0.9):
  for para in optimizer.param_groups:
    para['lr'] = para['lr']*decay_rate

重写torch.optim,加上L1正则

查看torch.optim.SGD等Optimizer的源码,发现没有L1正则的选项,而L1正则更容易得到稀疏解。

这个时候,可以更改/home/smiles/anaconda2/lib/python2.7/site-packages/torch/optim/sgd.py文件,模拟L2正则化的操作。

L1正则化求导如下:

dw = 1 * sign(w)

更改后的sgd.py如下:

import torch
from torch.optim.optimizer import Optimizer, required

class SGD(Optimizer):
  def __init__(self, params, lr=required, momentum=0, dampening=0,
         weight_decay1=0, weight_decay2=0, nesterov=False):
    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
            weight_decay1=weight_decay1, weight_decay2=weight_decay2, nesterov=nesterov)
    if nesterov and (momentum <= 0 or dampening != 0):
      raise ValueError("Nesterov momentum requires a momentum and zero dampening")
    super(SGD, self).__init__(params, defaults)

  def __setstate__(self, state):
    super(SGD, self).__setstate__(state)
    for group in self.param_groups:
      group.setdefault('nesterov', False)

  def step(self, closure=None):
    """Performs a single optimization step.

    Arguments:
      closure (callable, optional): A closure that reevaluates the model
        and returns the loss.
    """
    loss = None
    if closure is not None:
      loss = closure()

    for group in self.param_groups:
      weight_decay1 = group['weight_decay1']
      weight_decay2 = group['weight_decay2']
      momentum = group['momentum']
      dampening = group['dampening']
      nesterov = group['nesterov']

      for p in group['params']:
        if p.grad is None:
          continue
        d_p = p.grad.data
        if weight_decay1 != 0:
          d_p.add_(weight_decay1, torch.sign(p.data))
        if weight_decay2 != 0:
          d_p.add_(weight_decay2, p.data)
        if momentum != 0:
          param_state = self.state[p]
          if 'momentum_buffer' not in param_state:
            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
            buf.mul_(momentum).add_(d_p)
          else:
            buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
          if nesterov:
            d_p = d_p.add(momentum, buf)
          else:
            d_p = buf

        p.data.add_(-group['lr'], d_p)

    return loss

一个使用的例子:

optimizer = SGD([
        {'params': model.features12.parameters()},
        {'params': model.features22.parameters()},
        {'params': model.features32.parameters()},
        {'params': model.features42.parameters()},
        {'params': model.features52.parameters()},
      ], weight_decay1=5e-4, lr=1e-1, momentum=0.9)

以上这篇关于torch.optim的灵活使用详解(包括重写SGD,加上L1正则)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python下MySQLdb用法实例分析
Jun 08 Python
python3中str(字符串)的使用教程
Mar 23 Python
对python实时得到鼠标位置的示例讲解
Oct 14 Python
详解Python数据可视化编程 - 词云生成并保存(jieba+WordCloud)
Mar 26 Python
python打开windows应用程序的实例
Jun 28 Python
Python CVXOPT模块安装及使用解析
Aug 01 Python
pycharm创建scrapy项目教程及遇到的坑解析
Aug 15 Python
python模拟预测一下新型冠状病毒肺炎的数据
Feb 01 Python
Python实现http接口自动化测试的示例代码
Oct 09 Python
PyCharm2019.3永久激活破解详细图文教程,亲测可用(不定期更新)
Oct 29 Python
Python 求向量的余弦值操作
Mar 04 Python
Python sklearn分类决策树方法详解
Sep 23 Python
Python sys模块常用方法解析
Feb 20 #Python
pytorch 实现在一个优化器中设置多个网络参数的例子
Feb 20 #Python
pytorch ImageFolder的覆写实例
Feb 20 #Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
You might like
php跨域cookie共享使用方法
2014/02/20 PHP
php中count获取多维数组长度的方法
2014/11/03 PHP
SAE实时日志接口SDK用法示例
2016/10/09 PHP
javascript学习笔记(十五) js间歇调用和超时调用
2012/06/20 Javascript
图片上传插件jquery.uploadify详解
2013/11/15 Javascript
浅析js中的浮点型运算问题
2014/01/06 Javascript
编写自己的jQuery提示框(Tip)插件
2015/02/05 Javascript
jQuery实现图片轮播特效代码分享
2015/09/15 Javascript
Jquery+Ajax+PHP+MySQL实现分类列表管理(上)
2015/10/28 Javascript
基于JQuery及AJAX实现名人名言随机生成器
2017/02/10 Javascript
JS表单验证方法实例小结【电话、身份证号、Email、中文、特殊字符、身份证号等】
2017/02/14 Javascript
ES6(ECMAScript 6)新特性之模板字符串用法分析
2017/04/01 Javascript
layui点击按钮添加可编辑的一行方法
2018/08/15 Javascript
微信公众号获取用户地理位置并列出附近的门店的示例代码
2019/07/25 Javascript
Vue单文件组件开发实现过程详解
2020/07/30 Javascript
JavaScript async/await原理及实例解析
2020/12/02 Javascript
Python 实现引用其他.py文件中的类和类的方法
2018/04/29 Python
树莓派动作捕捉抓拍存储图像脚本
2019/06/22 Python
python基础教程之while循环
2019/08/14 Python
python之PyQt按钮右键菜单功能的实现代码
2019/08/17 Python
python统计字符的个数代码实例
2020/02/07 Python
python实现交并比IOU教程
2020/04/16 Python
浅谈Python里面None True False之间的区别
2020/07/09 Python
Python 在 VSCode 中使用 IPython Kernel 的方法详解
2020/09/05 Python
10个最常见的HTML5面试题 附答案
2016/06/06 HTML / CSS
如何在发生故障的节点上重新安装 SQL Server
2013/03/14 面试题
既然说Ruby中一切都是对象,那么Ruby中类也是对象吗
2013/01/26 面试题
电大自我鉴定
2013/10/27 职场文书
打架检讨书100字
2014/01/08 职场文书
学生吸烟检讨书
2014/09/14 职场文书
2015年防灾减灾工作总结
2015/07/24 职场文书
2016年公司中秋节致辞
2015/11/26 职场文书
2016高校自主招生自荐信范文
2016/01/28 职场文书
HTML+CSS+JS实现图片的瀑布流布局的示例代码
2021/04/22 HTML / CSS
python四个坐标点对图片区域最小外接矩形进行裁剪
2021/06/04 Python
SQL SERVER触发器详解
2022/02/24 SQL Server