pytorch载入预训练模型后,实现训练指定层


Posted in Python onJanuary 06, 2020

1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:

pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)

strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。

2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  if name 满足某些条件:
    value.requires_grad = False

# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:

# 载入预训练模型参数后...
for name, value in model.named_parameters():
  print(name)
# 或
print(model.state_dict().keys())

假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:

'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',

假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
               {'params':model.decoder.parameters()}
               ],
               lr=1e-4, momentum=0.9)

代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。

在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。

以上这篇pytorch载入预训练模型后,实现训练指定层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

参考:

Python 相关文章推荐
videocapture库制作python视频高速传输程序
Dec 23 Python
Python实现计算文件夹下.h和.cpp文件的总行数
Apr 23 Python
Python实现将绝对URL替换成相对URL的方法
Jun 28 Python
使用python为mysql实现restful接口
Jan 05 Python
详解python使用Nginx和uWSGI来运行Python应用
Jan 09 Python
取numpy数组的某几行某几列方法
Apr 03 Python
对Python中type打开文件的方式介绍
Apr 28 Python
Python使用Windows API创建窗口示例【基于win32gui模块】
May 09 Python
利用matplotlib为图片上添加触发事件进行交互
Apr 23 Python
Scrapy基于scrapy_redis实现分布式爬虫部署的示例
Sep 29 Python
在终端启动Python时报错的解决方案
Nov 20 Python
python 实现的IP 存活扫描脚本
Dec 10 Python
python与mysql数据库交互的实现
Jan 06 #Python
win10系统下python3安装及pip换源和使用教程
Jan 06 #Python
基于python实现文件加密功能
Jan 06 #Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 #Python
如何使用python实现模拟鼠标点击
Jan 06 #Python
pytorch 实现查看网络中的参数
Jan 06 #Python
Python3 虚拟开发环境搭建过程(图文详解)
Jan 06 #Python
You might like
第七节 类的静态成员 [7]
2006/10/09 PHP
一道求$b相对于$a的相对路径的php代码
2010/08/08 PHP
PHP文章采集URL补全函数(FormatUrl)
2012/08/02 PHP
比较discuz和ecshop的截取字符串函数php版
2012/09/03 PHP
领悟php接口中interface存在的意义
2013/06/27 PHP
php获取文件内容最后一行示例
2014/01/09 PHP
thinkphp制作404跳转页的简单实现方法
2016/09/22 PHP
详解php实现页面静态化原理
2017/06/21 PHP
JavaScript中圆括号()和方括号[]的特殊用法疑问解答
2013/08/06 Javascript
IE、FF浏览器下修改标签透明度
2014/01/28 Javascript
jQuery避免$符和其他JS库冲突的方法对比
2014/02/20 Javascript
自己用jQuery写了一个图片的马赛克消失效果
2014/05/04 Javascript
常用的JS验证和函数汇总
2014/12/23 Javascript
javascript实现详细时间提醒信息效果的方法
2015/03/11 Javascript
利用jQuery实现打字机字幕效果实例代码
2016/09/02 Javascript
vue组件间通信解析
2017/03/01 Javascript
ztree实现左边动态生成树右边为内容详情功能
2017/11/03 Javascript
详解Angular中实现自定义组件的双向绑定的两种方法
2018/11/23 Javascript
浅谈Vue CLI 3结合Lerna进行UI框架设计
2019/04/14 Javascript
解决vue单页面多个组件嵌套监听浏览器窗口变化问题
2020/07/30 Javascript
vant-ui组件调用Dialog弹窗异步关闭操作
2020/11/04 Javascript
[40:57]TI4 循环赛第二日 iG vs EG
2014/07/11 DOTA
python采用requests库模拟登录和抓取数据的简单示例
2014/07/05 Python
Python之自动获取公网IP的实例讲解
2017/10/01 Python
Python环境搭建之OpenCV的步骤方法
2017/10/20 Python
解读python logging模块的使用方法
2018/04/17 Python
django+mysql的使用示例
2018/11/23 Python
python 制作简单的音乐播放器
2020/11/25 Python
css3实现二维码扫描特效的示例
2020/10/29 HTML / CSS
Structs界面控制层技术
2013/10/11 面试题
儿科护理实习自我鉴定
2013/09/19 职场文书
《最佳路径》教学反思
2014/04/13 职场文书
小学语文课后反思精选
2014/04/25 职场文书
港澳通行证委托书怎么写
2014/08/02 职场文书
综合测评自我评价
2015/03/06 职场文书
组织委员竞选稿
2015/11/21 职场文书