python PyTorch参数初始化和Finetune


Posted in Python onFebruary 11, 2018

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

python PyTorch参数初始化和Finetune

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
  elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
    m.weight.data.fill_(1)
    m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
  param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)

# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
           model.parameters())

optimizer = torch.optim.SGD([
      {'params': base_params},
      {'params': model.fc.parameters(), 'lr': 1e-3}
      ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python打造出适合自己的定制化Eclipse IDE
Mar 02 Python
Python 3.x 连接数据库示例(pymysql 方式)
Jan 19 Python
Python实现的选择排序算法示例
Nov 29 Python
Django自定义manage命令实例代码
Feb 11 Python
Python实现获取前100组勾股数的方法示例
May 04 Python
Python3利用Dlib19.7实现摄像头人脸识别的方法
May 11 Python
python随机在一张图像上截取任意大小图片的方法
Jan 24 Python
Python 串口读写的实现方法
Jun 12 Python
利用python实现周期财务统计可视化
Aug 25 Python
python模块hashlib(加密服务)知识点讲解
Nov 25 Python
Python:二维列表下标互换方式(矩阵转置)
Dec 02 Python
django-csrf使用和禁用方式
Mar 13 Python
Python装饰器用法示例小结
Feb 11 #Python
python PyTorch预训练示例
Feb 11 #Python
TensorFlow中权重的随机初始化的方法
Feb 11 #Python
python的staticmethod与classmethod实现实例代码
Feb 11 #Python
Python语言的变量认识及操作方法
Feb 11 #Python
利用Opencv中Houghline方法实现直线检测
Feb 11 #Python
tensorflow输出权重值和偏差的方法
Feb 10 #Python
You might like
如何使用动态共享对象的模式来安装PHP
2006/10/09 PHP
资料注册后发信小技巧
2006/10/09 PHP
PHP has encountered an Access Violation 错误的解决方法
2010/01/17 PHP
什么是OneThink oneThink后台添加插件步骤
2016/04/13 PHP
拖动Html元素集合 Drag and Drop any item
2006/12/22 Javascript
JS 继承实例分析
2008/11/04 Javascript
在JS中如何调用JSP中的变量
2014/01/22 Javascript
当某个文本框成为焦点时即清除文本框内容
2014/04/28 Javascript
了不起的node.js读书笔记之mongodb数据库交互
2014/12/22 Javascript
JavaScript中对象介绍
2014/12/31 Javascript
JavaScript里实用的原生API汇总
2015/05/14 Javascript
Jquery和angularjs获取check框选中的值的方法汇总
2016/01/17 Javascript
vue.js 表格分页ajax 异步加载数据
2016/10/18 Javascript
微信小程序开发之toast等弹框提示使用教程
2017/06/08 Javascript
vue router仿天猫底部导航栏功能
2017/10/18 Javascript
vue.js 底部导航栏 一级路由显示 子路由不显示的解决方法
2018/03/09 Javascript
vue打包的时候自动将px转成rem的操作方法
2018/06/20 Javascript
详解Vue组件之间通信的七种方式
2019/04/14 Javascript
vue中axios的二次封装实例讲解
2019/10/14 Javascript
微信小程序自定义navigationBar顶部导航栏适配所有机型(附完整案例)
2020/04/26 Javascript
js实现贪吃蛇游戏(简易版)
2020/09/29 Javascript
python中多层嵌套列表的拆分方法
2018/07/02 Python
Python logging模块用法示例
2018/08/28 Python
Python logging设置和logger解析
2019/08/28 Python
Pytorch修改ResNet模型全连接层进行直接训练实例
2019/09/10 Python
用 python 进行微信好友信息分析
2020/11/28 Python
中国旅游网站:同程旅游
2016/09/11 全球购物
Gloeilampgoedkoop荷兰:在线购买灯泡
2019/02/16 全球购物
西班牙最大的在线滑板和街头服饰商店:Fillow.net
2019/04/15 全球购物
傲盾软件面试题
2015/08/17 面试题
介绍下static、final、abstract区别
2015/01/30 面试题
临床医学专业个人的自我评价
2013/09/27 职场文书
关于工作时间玩手机的检讨书
2014/09/18 职场文书
村党支部书记四风问题个人对照检查材料思想汇报
2014/10/06 职场文书
Vue3.0写自定义指令的简单步骤记录
2021/06/27 Vue.js
前端JavaScript大管家 package.json
2021/11/02 Javascript