pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用Socket(Https)Post登录百度的实现代码
May 18 Python
python中__slots__用法实例
Jun 04 Python
Django实现的自定义访问日志模块示例
Jun 23 Python
PyQt4实现下拉菜单可供选择并打印出来
Apr 20 Python
python保存数据到本地文件的方法
Jun 23 Python
使用 Python 玩转 GitHub 的贡献板(推荐)
Apr 04 Python
python打包exe开机自动启动的实例(windows)
Jun 28 Python
使用python实现kNN分类算法
Oct 16 Python
Python中的引用和拷贝实例解析
Nov 14 Python
pycharm修改file type方式
Nov 19 Python
Python 中 sorted 如何自定义比较逻辑
Feb 02 Python
python unittest单元测试的步骤分析
Aug 02 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
php中用foreach来操作数组的代码
2011/07/17 PHP
php函数array_merge用法一例(合并同类数组)
2013/02/03 PHP
浅析十款PHP开发框架的对比
2013/07/05 PHP
php多功能图片处理类分享(php图片缩放类)
2014/03/14 PHP
跟我学Laravel之安装Laravel
2014/10/15 PHP
Javascript实例教程(19) 使用HoTMetal(7)
2006/12/23 Javascript
深入理解JavaScript系列(11) 执行上下文(Execution Contexts)
2012/01/15 Javascript
Javascript创建自定义对象 创建Object实例添加属性和方法
2012/06/04 Javascript
jQuery中:radio选择器用法实例
2015/01/03 Javascript
jQuery中offsetParent()方法用法实例
2015/01/19 Javascript
jquery滚动加载数据的方法
2015/03/09 Javascript
JQ选择器_选择同类元素的第N个子元素的实现方法
2016/09/08 Javascript
vue中实现图片和文件上传的示例代码
2018/03/16 Javascript
微信小程序开发之改变data中数组或对象的某一属性值
2018/07/05 Javascript
VUE 实现滚动监听 导航栏置顶的方法
2018/09/11 Javascript
微信小程序JS加载esmap地图的实例详解
2019/09/04 Javascript
微信小程序实现蓝牙打印
2019/09/23 Javascript
react MPA 多页配置详解
2019/10/18 Javascript
小程序跳转到的H5页面再跳转回跳小程序的方法
2020/03/06 Javascript
基于原生JS封装的Modal对话框插件的示例代码
2020/09/09 Javascript
Express 配置HTML页面访问的实现
2020/11/01 Javascript
[01:14]DOTA2亚洲邀请赛小组赛赛前花絮
2017/03/27 DOTA
python文件操作相关知识点总结整理
2016/02/22 Python
解决Pandas to_json()中文乱码,转化为json数组的问题
2018/05/10 Python
PyQt4 treewidget 选择改变颜色,并设置可编辑的方法
2019/06/17 Python
Python 多个图同时在不同窗口显示的实现方法
2019/07/07 Python
详解python logging日志传输
2020/07/01 Python
python3中calendar返回某一时间点实例讲解
2020/11/18 Python
美国购买当代和现代家具网站:MODTEMPO
2018/07/20 全球购物
沃尔玛旗下墨西哥超市:Bodega Aurrera
2020/11/13 全球购物
ajax是什么及其工作原理
2012/02/08 面试题
驾驶员岗位职责
2014/01/29 职场文书
党员评议思想汇报
2014/10/08 职场文书
银行给客户的感谢信
2015/01/23 职场文书
css实现左上角飘带效果的完整代码
2022/03/18 HTML / CSS
如何在Python中妥善使用进度条详解
2022/04/05 Python