pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
python写的一个文本编辑器
Jan 23 Python
使用python编写批量卸载手机中安装的android应用脚本
Jul 21 Python
Python编写百度贴吧的简单爬虫
Apr 02 Python
Python中对数组集进行按行打乱shuffle的方法
Nov 08 Python
Python编程在flask中模拟进行Restful的CRUD操作
Dec 28 Python
python实现桌面气泡提示功能
Jul 29 Python
解决Python对齐文本字符串问题
Aug 28 Python
解决django 向mysql中写入中文字符出错的问题
May 18 Python
python代码实现将列表中重复元素之间的内容全部滤除
May 22 Python
安装pytorch时报sslerror错误的解决方案
May 17 Python
使用Pytorch训练two-head网络的操作
May 28 Python
Python Pandas pandas.read_sql_query函数实例用法分析
Jun 21 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
php在服务器执行exec命令失败的解决方法
2012/03/03 PHP
php中HTTP_REFERER函数用法实例
2014/11/21 PHP
PHP实现获取并生成数据库字典的方法
2016/05/04 PHP
基于swoole实现多人聊天室
2018/06/14 PHP
PHP使用gearman进行异步的邮件或短信发送操作详解
2020/02/27 PHP
脚本吧 - 幻宇工作室用到js,超强推荐share.js
2006/12/23 Javascript
JavaScript 学习笔记(十三)Dom创建表格
2010/01/21 Javascript
jquery异步调用页面后台方法‏(asp.net)
2011/03/01 Javascript
JQuery入门—JQuery程序的代码风格详细介绍
2013/01/03 Javascript
js拼接html注意问题示例探讨
2014/07/14 Javascript
使用jquery+CSS实现控制打印样式
2014/12/31 Javascript
JavaScript实现当网页加载完成后执行指定函数的方法
2015/03/21 Javascript
jQuery使用addClass()方法给元素添加多个class样式
2015/03/26 Javascript
在JavaScript中操作数组之map()方法的使用
2015/06/09 Javascript
前端学习笔记style,currentStyle,getComputedStyle的用法与区别
2016/05/28 Javascript
AngularJS ng-app 指令实例详解
2016/07/30 Javascript
jQuery实现页码跳转式动态数据分页
2017/12/31 jQuery
JS实现获取数组中最大值或最小值功能示例
2019/03/02 Javascript
vue实现一个获取按键展示快捷键效果的Input组件
2021/01/13 Vue.js
[02:56]《DAC最前线》之国外战队抵达上海备战亚洲邀请赛
2015/01/28 DOTA
[01:29:31]VP VS VG Supermajor小组赛胜者组第二轮 BO3第一场 6.2
2018/06/03 DOTA
分享一下Python 开发者节省时间的10个方法
2015/10/02 Python
教你用python3根据关键词爬取百度百科的内容
2016/08/18 Python
python实时检测键盘输入函数的示例
2019/07/17 Python
PyQT5 实现快捷键复制表格数据的方法示例
2020/06/19 Python
香港化妆品经销商:我的公主
2016/08/05 全球购物
西班牙伏林航空公司:Vueling
2016/08/05 全球购物
编程输出如下图形
2013/11/24 面试题
陈胜吴广起义口号
2014/06/20 职场文书
国际金融专业自荐信
2014/07/05 职场文书
《周恩来的四个昼夜》观后思想汇报范文两篇
2014/09/10 职场文书
大学生见习报告范文
2014/11/03 职场文书
2015年酒店销售部工作总结
2015/07/24 职场文书
朋友离别感言
2015/08/04 职场文书
初中班主任工作随笔
2015/08/15 职场文书
AJAX学习笔记
2021/05/18 Javascript