解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
新手常见6种的python报错及解决方法
Mar 09 Python
python主线程捕获子线程的方法
Jun 17 Python
python保存文件方法小结
Jul 27 Python
对Python3中bytes和HexStr之间的转换详解
Dec 04 Python
python requests库爬取豆瓣电视剧数据并保存到本地详解
Aug 10 Python
python通过robert、sobel、Laplace算子实现图像边缘提取详解
Aug 21 Python
python 三元运算符使用解析
Sep 16 Python
Python使用Opencv实现图像特征检测与匹配的方法
Oct 30 Python
tensorflow实现读取模型中保存的值 tf.train.NewCheckpointReader
Feb 10 Python
Python+Appium实现自动化测试的使用步骤
Mar 24 Python
pyinstaller打包单文件时--uac-admin选项不起作用怎么办
Apr 15 Python
解决python便携版无法直接运行py文件的问题
Sep 01 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
使用PHP数组实现无限分类,不使用数据库,不使用递归.
2006/12/09 PHP
浅析PHP微信支付通知的处理方式
2014/05/25 PHP
浅谈json_encode用法
2015/03/05 PHP
php die()与exit()的区别实例详解
2016/12/03 PHP
PHP bin2hex()函数基础实例讲解
2019/02/11 PHP
JS 参数传递的实际应用代码分析
2009/09/13 Javascript
jQuery动画animate方法使用介绍
2013/05/06 Javascript
JS验证控制输入中英文字节长度(input、textarea等)具体实例
2013/06/21 Javascript
JS获取图片实际宽高及根据图片大小进行自适应
2013/08/11 Javascript
jquery实现的下拉和收缩效果示例
2014/08/21 Javascript
Javascript检查图片大小不要让大图片撑破页面
2014/11/04 Javascript
跟我学习javascript的for循环和for...in循环
2015/11/18 Javascript
原生JS实现匀速图片轮播动画
2016/10/18 Javascript
js实现按座位号抽奖
2017/04/05 Javascript
Vue键盘事件用法总结
2017/04/18 Javascript
详解微信第三方小程序代开发
2017/06/23 Javascript
JavaScript动态绑定详解
2017/09/14 Javascript
vue-router 组件复用问题详解
2018/01/22 Javascript
详解如何在vue项目中引入elementUI组件
2018/02/11 Javascript
layer弹出层自定义提交取消按钮的例子
2019/09/10 Javascript
[07:26]2015国际邀请赛第二日TOP10集锦
2015/08/06 DOTA
python使用sorted函数对列表进行排序的方法
2015/04/04 Python
Python爬虫番外篇之Cookie和Session详解
2017/12/27 Python
python列表list保留顺序去重的实例
2018/12/14 Python
python实现简单成绩录入系统
2019/09/19 Python
亚马逊印度站:Amazon.in
2017/10/15 全球购物
PHP引擎php.ini参数优化深入讲解
2021/03/24 PHP
实习护士自我鉴定
2013/10/13 职场文书
银行领导证婚词
2014/01/11 职场文书
服务员自我评价
2014/01/25 职场文书
意向书范文
2014/03/31 职场文书
四年级学生评语大全
2014/04/21 职场文书
债务纠纷委托书
2014/08/30 职场文书
水利专业大学生职业生涯规划书范文
2014/09/17 职场文书
四风问题个人对照检查剖析材料
2014/09/27 职场文书
vscode中使用npm安装babel的方法
2021/08/02 Javascript