解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python遍历C盘dll文件的方法
May 06 Python
Python实现给qq邮箱发送邮件的方法
May 28 Python
深入讲解Python中的迭代器和生成器
Oct 26 Python
python2.7+selenium2实现淘宝滑块自动认证功能
Feb 24 Python
python机器学习之贝叶斯分类
Mar 26 Python
VSCode下配置python调试运行环境的方法
Apr 06 Python
Python 3.3实现计算两个日期间隔秒数/天数的方法示例
Jan 07 Python
Python 实现微信防撤回功能
Apr 29 Python
Python tkinter实现图片标注功能(完整代码)
Dec 08 Python
python实现对变位词的判断方法
Apr 05 Python
Python3获取cookie常用三种方案
Oct 05 Python
Python语言中的数据类型-序列
Feb 24 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
PHP实现的简单异常处理类示例
2017/05/04 PHP
PHP mysqli事务操作常用方法分析
2017/07/22 PHP
PHP检查URL包含特定字符串实例方法
2019/02/11 PHP
小程序微信退款功能实现方法详解【基于thinkPHP】
2019/05/05 PHP
页面中body onload 和 window.onload 冲突的问题的解决
2009/07/01 Javascript
jquery判断复选框是否选中进行答题提示特效
2015/12/10 Javascript
探究Javascript模板引擎mustache.js使用方法
2016/01/26 Javascript
js获取所有checkbox的值的简单实例
2016/05/30 Javascript
React-Native使用Mobx实现购物车功能
2017/09/14 Javascript
微信小程序getPhoneNumber获取用户手机号
2017/09/29 Javascript
vue使用rem实现 移动端屏幕适配
2018/09/26 Javascript
微信小程序实现带缩略图轮播效果
2018/11/04 Javascript
浅谈Vue CLI 3结合Lerna进行UI框架设计
2019/04/14 Javascript
详解element-ui表格中勾选checkbox,高亮当前行
2019/09/02 Javascript
vue.js中ref及$refs的使用方法解析
2019/10/08 Javascript
Vue实现商品飞入购物车效果(电商项目)
2019/11/26 Javascript
Python使用Pycrypto库进行RSA加密的方法详解
2016/06/06 Python
Python实现简单的多任务mysql转xml的方法
2017/02/08 Python
tensorflow学习笔记之简单的神经网络训练和测试
2018/04/15 Python
Python3内置模块之json编解码方法小结【推荐】
2020/12/09 Python
PyCharm无法引用自身项目解决方式
2020/02/12 Python
CSS3 选择器 基本选择器介绍
2012/01/21 HTML / CSS
Amaze UI 文件选择域的示例代码
2020/08/26 HTML / CSS
玩具反斗城天猫官方旗舰店:享誉全球的玩具店
2017/10/10 全球购物
用C#语言写出与SQLSERVER访问时的具体过程
2013/04/16 面试题
安全资料员岗位职责
2013/12/14 职场文书
办公室副主任职责范本
2014/03/08 职场文书
分公司经理任命书
2014/06/05 职场文书
自我推荐信怎么写
2015/03/24 职场文书
毕业论文致谢部分怎么写
2015/05/14 职场文书
盲山观后感
2015/06/11 职场文书
运动会致辞稿
2015/07/29 职场文书
学习社交礼仪心得体会
2016/01/22 职场文书
你会写报告?产品体验报告到底该怎么写?
2019/08/14 职场文书
JAVA API 实用类 String详解
2021/10/05 Java/Android
利用For循环遍历Python字典的三种方法实例
2022/03/25 Python