pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python数据结构之Array用法实例
Oct 09 Python
Python制作爬虫采集小说
Oct 25 Python
Linux上安装Python的PIL和Pillow库处理图片的实例教程
Jun 23 Python
简单谈谈Python中的元祖(Tuple)和字典(Dict)
Apr 21 Python
Python 通过URL打开图片实例详解
Jun 01 Python
Python中音频处理库pydub的使用教程
Jun 07 Python
python Socket之客户端和服务端握手详解
Sep 18 Python
python 内置模块详解
Jan 01 Python
对Python3 解析html的几种操作方式小结
Feb 16 Python
Python语言检测模块langid和langdetect的使用实例
Feb 19 Python
pycharm 批量修改变量名称的方法
Aug 01 Python
Python打包工具PyInstaller的安装与pycharm配置支持PyInstaller详细方法
Feb 27 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
php类常量的使用详解
2013/06/08 PHP
php实现mysql数据库操作类分享
2014/02/14 PHP
Yii2主题(Theme)用法详解
2016/07/23 PHP
js控制淡入淡出示例代码
2013/11/12 Javascript
js如何判断用户是在PC端和还是移动端访问
2014/04/24 Javascript
js文件包含的几种方式介绍
2014/09/28 Javascript
Javascript中的默认参数详解
2014/10/22 Javascript
基于JavaScript实现生成名片、链接等二维码
2015/09/20 Javascript
jquery实现点击弹出可放大居中及关闭的对话框(附demo源码下载)
2016/05/10 Javascript
jquery html动态添加的元素绑定事件详解
2016/05/24 Javascript
Angular2学习笔记——详解路由器模型(Router)
2016/12/02 Javascript
js实现数字递增特效【仿支付宝我的财富】
2017/05/05 Javascript
Bootstrap输入框组件使用详解
2017/06/09 Javascript
angularJs的ng-class切换class
2017/06/23 Javascript
使用原生js+canvas实现模拟心电图的实例
2017/09/20 Javascript
前端vue如何使用高德地图
2020/11/05 Javascript
python多进程操作实例
2014/11/21 Python
开源软件包和环境管理系统Anaconda的安装使用
2017/09/04 Python
Python决策树之基于信息增益的特征选择示例
2018/06/25 Python
Python读取mat文件,并转为csv文件的实例
2018/07/04 Python
python去重,一个由dict组成的list的去重示例
2019/01/21 Python
python自动发邮件总结及实例说明【推荐】
2019/05/31 Python
通过cmd进入python的实例操作
2019/06/26 Python
python求加权平均值的实例(附纯python写法)
2019/08/22 Python
Python3 无重复字符的最长子串的实现
2019/10/08 Python
Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解
2020/02/11 Python
python eventlet绿化和patch原理
2020/11/21 Python
css3实现圆锥渐变conic-gradient效果
2020/02/12 HTML / CSS
Keds加拿大官网:购买帆布运动鞋和皮鞋
2019/09/26 全球购物
char型变量中能不能存贮一个中文汉字
2015/07/08 面试题
春节请假条
2014/04/11 职场文书
应届生找工作求职信
2014/06/24 职场文书
股东授权委托书范本
2014/09/13 职场文书
邀请函怎么写
2015/01/30 职场文书
2015年领导干部廉洁自律工作总结
2015/05/26 职场文书
MySQL去除密码登录告警的方法
2022/04/20 MySQL