pytorch加载预训练模型与自己模型不匹配的解决方案


Posted in Python onMay 13, 2021

pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。

两个有序字典找不同

模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        err = 1

自己搭建模型的注意事项

搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。

model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
    if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
        continue
    model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)

完整的代码见自己搭建resnet18网络并加载torchvision自带权重

新增的改进代码

model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
    if m >= len1 or n >= len2:
        break
    layername1, layername2 = model_list1[m], model_list2[n]
    w1, w2 = model_dict1[layername1], model_dict2[layername2]
    if w1.shape != w2.shape:
        continue
    model_dict2[layername2] = model_dict1[layername1]
    m += 1
    n += 1
model.load_state_dict(model_dict2)

如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。

补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配

看代码吧~

#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
    def __init__(self, rgb_range):
        super(Classification_att, self).__init__()
        self.vgg19 =models.vgg19(pretrained=True)
        vgg = models.vgg19(pretrained=True).features
        conv_modules = [m for m in vgg]
        self.vgg_conv = nn.Sequential(*conv_modules[:37])
        classfi = models.vgg19(pretrained=True).classifier
        classif_modules = [n for n in classfi]
        self.vgg_class = nn.Sequential(*classif_modules[:4])
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.vgg_conv.parameters():
            p.requires_grad = False
        for p in self.vgg_class.parameters():
            p.requires_grad = False
        self.classifi = nn.Sequential(
            nn.Linear(4096, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
        )
 
    def forward(self, x):
        x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear', 
        align_corners=False)
        x = self.sub_mean(x)
        x = self.vgg_conv(x)  
        x = self.vgg_class(x)  #执行这部报错,说张量不匹配

原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的

查看vgg的pytorch源码发现是

x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步

所以自己的少了一步

x = torch.flatten(x, 1)

补上就好了!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python递归计算N!的方法
May 05 Python
Python fileinput模块使用实例
Jun 03 Python
Python向日志输出中添加上下文信息
May 24 Python
Python读取本地文件并解析网页元素的方法
May 21 Python
详解python Todo清单实战
Nov 01 Python
python遍历小写英文字母的方法
Jan 02 Python
Python3实现的判断环形链表算法示例
Mar 07 Python
如何通过python画loss曲线的方法
Jun 26 Python
基于Python的微信机器人开发 微信登录和获取好友列表实现解析
Aug 21 Python
python简单实现9宫格图片实例
Sep 03 Python
使用python操作lmdb对数据读取的实例
Dec 11 Python
python爬取企查查企业信息之selenium自动模拟登录企查查
Apr 08 Python
Python数据分析入门之教你怎么搭建环境
Pytorch 统计模型参数量的操作 param.numel()
May 13 #Python
Python机器学习算法之决策树算法的实现与优缺点
Python爬虫基础之爬虫的分类知识总结
pytorch中的numel函数用法说明
May 13 #Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
You might like
php读取html并截取字符串的简单代码
2009/11/30 PHP
PHP运行时强制显示出错信息的代码
2011/04/20 PHP
PHP大小写问题:函数名和类名不区分,变量名区分
2013/06/17 PHP
PHP使用内置函数生成图片的方法详解
2016/05/09 PHP
PHP安装GeoIP扩展根据IP获取地理位置及计算距离的方法
2016/07/01 PHP
PHP简单实现上一页下一页功能示例
2016/09/14 PHP
thinkPHP事务操作简单案例分析
2019/10/17 PHP
Jquery之Bind方法参数传递与接收的三种方法
2014/06/24 Javascript
jQuery的层级查找方式分析
2016/06/16 Javascript
js实现登录验证码
2016/12/22 Javascript
AngularJs中 ng-repeat指令中实现含有自定义指令的动态html的方法
2017/01/19 Javascript
JavaScript函数节流的两种写法
2017/04/07 Javascript
AngularJS中的拦截器实例详解
2017/04/07 Javascript
vue使用自定义指令实现拖拽
2021/01/29 Javascript
nodejs实现聊天机器人功能
2019/09/19 NodeJs
node.js使用stream模块实现自定义流示例
2020/02/13 Javascript
[36:19]2018DOTA2亚洲邀请赛 小组赛 A组加赛 Newbee vs LGD
2018/04/03 DOTA
一篇不错的Python入门教程
2007/02/08 Python
Python程序设计入门(4)模块和包
2014/06/16 Python
python中PIL安装简单教程
2016/04/21 Python
Python全局变量用法实例分析
2016/07/19 Python
python之文件的读写和文件目录以及文件夹的操作实现代码
2016/08/28 Python
tensorflow学习笔记之简单的神经网络训练和测试
2018/04/15 Python
pytorch + visdom CNN处理自建图片数据集的方法
2018/06/04 Python
python中的for循环
2018/09/28 Python
python操作kafka实践的示例代码
2019/06/19 Python
通过selenium抓取某东的TT购买记录并分析趋势过程解析
2019/08/15 Python
Django认证系统user对象实现过程解析
2020/03/02 Python
浅谈Python的方法解析顺序(MRO)
2020/03/05 Python
安装python3.7编译器后如何正确安装opnecv的方法详解
2020/06/16 Python
基础的CSS3弹性盒Flexbox布局使用实例
2016/04/08 HTML / CSS
解决H5的a标签的download属性下载service上的文件出现跨域问题
2019/07/16 HTML / CSS
北美大型运动类产品商城:Champs Sports
2017/01/12 全球购物
幼儿园师德演讲稿
2014/05/06 职场文书
电影圆明园观后感
2015/06/03 职场文书
Python使用openpyxl模块处理Excel文件
2022/06/05 Python