Pytorch中实现只导入部分模型参数的方式


Posted in Python onJanuary 02, 2020

我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?

好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。

请看下面的一个例子。

我们先搭建一个小小的网络。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我们保存这个网络的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我们现在试着导入之前保存的模型参数。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出现了没有在模型文件中找到error中的关键字的错误。

现在我们这样导入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。

为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。

以上这篇Pytorch中实现只导入部分模型参数的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python删除文件示例分享
Jan 28 Python
Python的Bottle框架中获取制定cookie的教程
Apr 24 Python
Python中with及contextlib的用法详解
Jun 08 Python
基于Linux系统中python matplotlib画图的中文显示问题的解决方法
Jun 15 Python
Python视频爬虫实现下载头条视频功能示例
May 07 Python
浅析Python四种数据类型
Sep 26 Python
Python中几种属性访问的区别与用法详解
Oct 10 Python
使用 Python 处理 JSON 格式的数据
Jul 22 Python
python中的Elasticsearch操作汇总
Oct 30 Python
Python图片的横坐标汉字实例
Dec 04 Python
Pytorch Tensor的统计属性实例讲解
Dec 30 Python
Python如何获取文件路径/目录
Sep 22 Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
python使用SQLAlchemy操作MySQL
Jan 02 #Python
pytorch 实现cross entropy损失函数计算方式
Jan 02 #Python
Matplotlib scatter绘制散点图的方法实现
Jan 02 #Python
You might like
php fsockopen解决办法 php实现多线程
2014/01/20 PHP
浅谈php中的循环while、do...while、for、foreach四种循环
2016/11/05 PHP
关于js类的定义
2011/06/28 Javascript
js 获取坐标 通过JS得到当前焦点(鼠标)的坐标属性
2013/01/04 Javascript
javascript动态控制服务器控件实例
2014/09/05 Javascript
node.js require() 源码解读
2015/12/13 Javascript
Bootstrap3制作自己的导航栏
2016/05/12 Javascript
JavaScript实现移动端滑动选择日期功能
2016/06/21 Javascript
原生js仿jquery一些常用方法(必看篇)
2016/09/20 Javascript
vue.js指令v-model使用方法
2017/03/20 Javascript
Vue服务端渲染和Vue浏览器端渲染的性能对比(实例PK )
2017/03/31 Javascript
vue之组件内监控$store中定义变量的变化详解
2019/11/08 Javascript
JS实现动态无缝轮播
2020/01/11 Javascript
Openlayers+EasyUI Tree动态实现图层控制
2020/09/28 Javascript
实现vuex原理的示例
2020/10/21 Javascript
[01:27]DOTA2电竞之夜 今夜共饮庆功酒
2014/08/02 DOTA
[42:25]EG vs Spirit Supermajor 败者组 BO3 第二场 6.4
2018/06/05 DOTA
python网络编程学习笔记(三):socket网络服务器
2014/06/09 Python
Python实现简单网页图片抓取完整代码实例
2017/12/15 Python
Python切片索引用法示例
2018/05/15 Python
Python实现的绘制三维双螺旋线图形功能示例
2018/06/23 Python
python调用百度语音REST API
2018/08/30 Python
如何搭建pytorch环境的方法步骤
2020/05/06 Python
Silk’n激光脱毛器官网:silkn.com
2016/10/06 全球购物
Tretorn美国官网:瑞典外套和鞋类品牌,抵御风雨
2018/07/19 全球购物
T3官网:头发造型工具
2019/12/26 全球购物
西安启天科技有限公司网络工程师面试题笔试题
2016/06/12 面试题
大学生优秀团员事迹材料
2014/01/30 职场文书
委托公证书
2014/04/08 职场文书
建筑专业毕业生自荐信
2014/05/25 职场文书
家长会标语
2014/06/24 职场文书
讲文明知礼仪演讲稿
2014/09/13 职场文书
乡镇组织委员个人整改措施
2014/09/16 职场文书
机票销售员态度不好检讨书
2014/09/27 职场文书
党员个人整改措施
2014/10/24 职场文书
秋季运动会加油词
2015/07/18 职场文书