解决Pytorch修改预训练模型时遇到key不匹配的情况


Posted in Python onJune 05, 2021

一、Pytorch修改预训练模型时遇到key不匹配

最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。

在我使用新赋值的网络模型时出现了key不匹配的问题

#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights) 
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
    ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
    net = ssd_net
    ...
    if args.resume:
        ...
    else:
        base_weights = torch.load(args.save_folder + args.basenet)
        #args.basenet为ssd_base.pth
        print('Loading base network...')
        ssd_net.vgg.load_state_dict(base_weights)

此时会如下出错误:

Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)

RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.

说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”

我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。

现在的问题是因为自己定义保存的模型key参数多了一个前缀。

可以通过如下语句进行修改,并加载

from collections import OrderedDict   #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**  
for k, v in base_weights.items():
    name = k[4:]   # remove `vgg.`,即只取vgg.0.weights的后面几位
    new_state_dict[name] = v 
    ssd_net.vgg.load_state_dict(new_state_dict)

此时就不会再出错了。

参考了这个篇。修改一下就可以应用到自己的模型啦。

//www.3water.com/article/214214.htm

二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘

最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。

KeyError: 'layer1.0.bn1.num_batches_tracked'

其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,

这个参数的作用如下:

训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1

如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.

所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.

有问题的代码:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        for i in state_dict:
            key = param_name + '.' + i
            state_dict[i].copy_(param_dict[key])
        del param_dict

对'num_batches_tracked进行过滤:

def load_specific_param(self, state_dict, param_name, model_path):
        param_dict = torch.load(model_path)
        param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
        for i in state_dict:
            key = param_name + '.' + i
            if 'num_batches_tracked' in key:
                continue
            state_dict[i].copy_(param_dict[key])
        del param_dict

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取当前计算机cpu数量的方法
Apr 18 Python
使用Python求解最大公约数的实现方法
Aug 20 Python
Python实现曲线点抽稀算法的示例
Oct 12 Python
Flask和Django框架中自定义模型类的表名、父类相关问题分析
Jul 19 Python
举例讲解Python常用模块
Mar 08 Python
Python 日期的转换及计算的具体使用详解
Jan 16 Python
pandas分组聚合详解
Apr 10 Python
Keras之自定义损失(loss)函数用法说明
Jun 10 Python
使用python脚本自动生成K8S-YAML的方法示例
Jul 12 Python
实现Python3数组旋转的3种算法实例
Sep 16 Python
使用python操作lmdb对数据读取的实例
Dec 11 Python
Python中文分词库jieba(结巴分词)详细使用介绍
Apr 07 Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
You might like
PHP Pear 安装及使用
2009/03/19 PHP
tp5.1 框架数据库-数据集操作实例分析
2020/05/26 PHP
PHP 实现重载
2021/03/09 PHP
JavaScript实现拼音排序的方法
2012/11/20 Javascript
简单谈谈javascript中的变量、作用域和内存问题
2015/08/30 Javascript
封装获取dom元素的简单实例
2016/07/08 Javascript
jQuery实现磁力图片跟随效果完整示例
2016/09/16 Javascript
jquery插件bootstrapValidator表单验证详解
2016/12/15 Javascript
关于vue中watch检测到不到对象属性的变化的解决方法
2018/02/08 Javascript
jquery点击回车键实现登录效果并默认焦点的方法
2018/03/09 jQuery
vue中使用heatmapjs的示例代码(结合百度地图)
2018/09/05 Javascript
js+html实现周岁年龄计算器
2019/06/25 Javascript
VScode格式化ESlint方法(最全最好用方法)
2019/09/10 Javascript
使用layui监听器监听select下拉框,事件绑定不成功的解决方法
2019/09/28 Javascript
Vue实现省市区三级联动
2020/12/27 Vue.js
[02:05]2014DOTA2西雅图国际邀请赛 BBC第二天小组赛总结
2014/07/11 DOTA
pycharm 使用心得(六)进行简单的数据库管理
2014/06/06 Python
Python实现的几个常用排序算法实例
2014/06/16 Python
轻松掌握python设计模式之策略模式
2016/11/18 Python
Python的语言类型(详解)
2017/06/24 Python
Python排序搜索基本算法之归并排序实例分析
2017/12/08 Python
python实现数据分析与建模
2019/07/11 Python
详解Python 循环嵌套
2020/07/09 Python
Python实例方法、类方法、静态方法区别详解
2020/09/05 Python
python归并排序算法过程实例讲解
2020/11/04 Python
HTML5地理定位与第三方工具百度地图的应用
2016/11/17 HTML / CSS
Marriott中国:万豪国际酒店查询预订
2016/09/02 全球购物
CLR与IL分别是什么含义
2016/08/23 面试题
毕业生简历自我评价范文
2014/04/09 职场文书
法律顾问服务方案
2014/05/15 职场文书
授权委托书(法人单位用)
2014/09/29 职场文书
物业保洁员岗位职责
2015/02/13 职场文书
2019请假条的基本格式及范文!
2019/07/05 职场文书
vue响应式原理与双向数据的深入解析
2021/06/04 Vue.js
OpenCV-Python直方图均衡化实现图像去雾
2021/06/07 Python
Python代码实现双链表
2022/05/25 Python