pytorch 预训练模型读取修改相关参数的填坑问题


Posted in Python onJune 05, 2021

pytorch 预训练模型读取修改相关参数的填坑

修改部分层,仍然调用之前的模型参数。

resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load(args.predir))
 
res_conv31 = Bottleneck_dilated(1024, 256,dilated_rate = 2)
print("---------------------",res_conv31)
print("---------------------",resnet.layer3[1])
 
res_conv31.load_state_dict(resnet.layer3[1].state_dict())

网络预训练模型与之前的模型对应不上,名称差个前缀

model_dict = model.state_dict()
# print(model_dict)
pretrained_dict = torch.load("/yzc/reid_testpcb/se_resnet50-ce0d4300.pth")
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

最后是修改参数名拿来用的,

from collections import OrderedDict
pretrained_dict = torch.load('premodel')
 
new_state_dict = OrderedDict()
 
# for k, v in mgn_state_dict.items():
#     name = k[7:]  # remove `module.`
#     new_state_dict[name] = v
# self.model = self.model.load_state_dict(new_state_dict)
 
for k, v in pretrained_dict.items():
    name = "model.module."+k   # remove `module.`
    # print(name)
    new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)

pytorch:加载预训练模型中的部分参数,并固定该部分参数(真实有效)

大家在学习pytorch时,可能想利用pytorch进行fine-tune,但是又烦恼于参数的加载问题。下面我将讲诉我的使用心得。

Step1: 加载预训练模型,并去除需要再次训练的层

#注意:需要重新训练的层的名字要和之前的不同。
model=resnet()#自己构建的模型,以resnet为例
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

Step2:固定部分参数

#k是可训练参数的名字,v是包含可训练参数的一个实体
#可以先print(k),找到自己想进行调整的层,并将该层的名字加入到if语句中:
for k,v in model.named_parameters():
    if k!='xxx.weight' and k!='xxx.bias' :
        v.requires_grad=False#固定参数

Step3:训练部分参数

#将要训练的参数放入优化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

Step4:检查部分参数是否固定

debug之后,程序正常运行,最好检查一下网络的参数是否真的被固定了,如何没固定,网络的状态接近于重新训练,可能会导致网络性能不稳定,也没办法得到想要得到的性能提升。

for k,v in model.named_parameters():
   if k!='xxx.weight' and k!='xxx.bias' :
   print(v.requires_grad)#理想状态下,所有值都是False

需要注意的是,操作失误最大的影响是,loss函数几乎不会发生变化,一直处于最开始的状态,这很可能是因为所有参数都被固定了。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
探索Python3.4中新引入的asyncio模块
Apr 08 Python
Python数据类型详解(四)字典:dict
May 12 Python
matplotlib绘制动画代码示例
Jan 02 Python
python3.6使用pymysql连接Mysql数据库
May 25 Python
python 定时器,实现每天凌晨3点执行的方法
Feb 20 Python
python并发编程多进程 互斥锁原理解析
Aug 20 Python
Python猴子补丁知识点总结
Jan 05 Python
Python3读写Excel文件(使用xlrd,xlsxwriter,openpyxl3种方式读写实例与优劣)
Feb 13 Python
Python列表如何更新值
May 27 Python
如何在python中处理配置文件代码实例
Sep 27 Python
pycharm 2020.2.4 pip install Flask 报错 Error:Non-zero exit code的问题
Dec 04 Python
有趣的二维码:使用MyQR和qrcode来制作二维码
May 10 Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
You might like
合并ThinkPHP配置文件以消除代码冗余的实现方法
2014/07/22 PHP
php生成二维码图片方法汇总
2016/12/17 PHP
Yii2框架实现数据库常用操作总结
2017/02/08 PHP
php合并数组并保留键值的实现方法
2018/03/12 PHP
删除重复数据的算法
2006/11/23 Javascript
JavaScript XML操作 封装类
2009/07/01 Javascript
JavaScript Cookie的读取和写入函数
2009/12/08 Javascript
jQuery 自动增长的文本输入框实现代码
2010/04/02 Javascript
中文路径导致unitpngfix.js不正常的解决方法
2013/06/26 Javascript
js识别不同浏览器基于userAgent做判断
2014/07/29 Javascript
jquery+javascript编写国籍控件
2015/02/12 Javascript
javascript实现实时输出当前的时间
2015/04/27 Javascript
JavaScript实现的浮动层框架用法实例分析
2015/10/10 Javascript
总结几道关于Node.js的面试问题
2017/01/11 Javascript
jQuery简单实现向列表动态添加新元素的方法示例
2017/12/25 jQuery
微信小程序停止其他视频播放当前视频的实例代码
2019/12/25 Javascript
python二叉树遍历的实现方法
2013/11/21 Python
在Python 3中实现类型检查器的简单方法
2015/07/03 Python
一篇文章读懂Python赋值与拷贝
2018/04/19 Python
python利用requests库进行接口测试的方法详解
2018/07/06 Python
python3中os.path模块下常用的用法总结【推荐】
2018/09/16 Python
Python设计模式之观察者模式原理与用法详解
2019/01/16 Python
Python 网络编程之TCP客户端/服务端功能示例【基于socket套接字】
2019/10/12 Python
关于Flask项目无法使用公网IP访问的解决方式
2019/11/19 Python
Python迷宫生成和迷宫破解算法实例
2019/12/24 Python
python 轮询执行某函数的2种方式
2020/05/03 Python
英国在线药房:Chemist.co.uk
2019/03/26 全球购物
以设计师精品品质提供快速时尚:Mostata
2019/05/10 全球购物
阿里巴巴的Oracle DBA笔试题答案-SQL tuning类
2016/04/03 面试题
会计应聘求职信范文
2013/12/17 职场文书
公司培训心得体会
2014/01/03 职场文书
餐饮商业计划书范文
2014/04/29 职场文书
驳回起诉裁定书
2015/05/19 职场文书
《蜜蜂引路》教学反思
2016/02/22 职场文书
golang goroutine顺序输出方式
2021/04/29 Golang
使用nginx配置访问wgcloud的方法
2021/06/26 Servers