解决pytorch 的state_dict()拷贝问题


Posted in Python onMarch 03, 2021

先说结论

model.state_dict()是浅拷贝,返回的参数仍然会随着网络的训练而变化。

应该使用deepcopy(model.state_dict()),或将参数及时序列化到硬盘。

再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。

补充:pytorch中state_dict的理解

在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

其实看了如下代码的输出应该就懂了

import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
  def __init__(self):
    super(TheModelClass, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
  print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
  print(var_name, "\t", optimizer.state_dict()[var_name])

输出如下:

Model's state_dict:
conv1.weight  torch.Size([6, 3, 5, 5])
conv1.bias  torch.Size([6])
conv2.weight  torch.Size([16, 6, 5, 5])
conv2.bias  torch.Size([16])
fc1.weight  torch.Size([120, 400])
fc1.bias  torch.Size([120])
fc2.weight  torch.Size([84, 120])
fc2.bias  torch.Size([84])
fc3.weight  torch.Size([10, 84])
fc3.bias  torch.Size([10])
Optimizer's state_dict:
state  {}
param_groups  [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]

我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!

补充:pytorch保存模型时报错***object has no attribute 'state_dict'

定义了一个类BaseNet并实例化该类:

net=BaseNet()

保存net时报错 object has no attribute 'state_dict'

torch.save(net.state_dict(), models_dir)

原因是定义类的时候不是继承nn.Module类,比如:

class BaseNet(object):
  def __init__(self):

把类定义改为

class BaseNet(nn.Module):
  def __init__(self):
    super(BaseNet, self).__init__()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
解决Python中由于logging模块误用导致的内存泄露
Apr 23 Python
Python模拟百度登录实例详解
Jan 20 Python
Python 爬虫学习笔记之单线程爬虫
Sep 21 Python
详解用Python处理HTML转义字符的5种方式
Dec 27 Python
scrapy spider的几种爬取方式实例代码
Jan 25 Python
python之从文件读取数据到list的实例讲解
Apr 19 Python
python中virtualenvwrapper安装与使用
May 20 Python
对Python的交互模式和直接运行.py文件的区别详解
Jun 29 Python
PyTorch实现更新部分网络,其他不更新
Dec 31 Python
pandas使用之宽表变窄表的实现
Apr 12 Python
pyspark 随机森林的实现
Apr 24 Python
基于Python采集爬取微信公众号历史数据
Nov 27 Python
解决pytorch 保存模型遇到的问题
Mar 03 #Python
解决pytorch 模型复制的一些问题
Mar 03 #Python
Pytorch模型迁移和迁移学习,导入部分模型参数的操作
Mar 03 #Python
pytorch 实现L2和L1正则化regularization的操作
Mar 03 #Python
Pytorch自定义Dataset和DataLoader去除不存在和空数据的操作
Mar 03 #Python
python爬取youtube视频的示例代码
Mar 03 #Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
Mar 03 #Python
You might like
关于PHP中的Class的几点个人看法
2006/10/09 PHP
Email+URL的判断和自动转换函数
2006/10/09 PHP
php简单定时执行任务的实现方法
2015/02/23 PHP
JQueryEasyUI Layout布局框架的使用
2013/04/08 Javascript
判断JS对象是否拥有某种属性的两种方式
2013/12/02 Javascript
jQuery html()方法使用不了无法显示内容的问题
2014/08/06 Javascript
浅谈javascript 函数属性和方法
2015/01/21 Javascript
js实现格式化金额,字符,时间的方法
2015/02/26 Javascript
javascript实现跨域的方法汇总
2015/06/25 Javascript
JS导出PDF插件的方法(支持中文、图片使用路径)
2016/07/12 Javascript
vue.js初学入门教程(2)
2016/11/07 Javascript
如何使用vuejs实现更好的Form validation?
2017/04/07 Javascript
获取当前按钮或者html的ID名称实例(推荐)
2017/06/23 Javascript
Vue2.0用户权限控制解决方案的示例
2018/02/10 Javascript
微信头像地址失效踩坑记附带解决方案
2019/09/23 Javascript
vue实现图片懒加载的方法分析
2020/02/05 Javascript
vue-quill-editor插入图片路径太长问题解决方法
2021/01/08 Vue.js
[03:22]DSPL第一期精彩集锦:酷炫到底!
2014/11/07 DOTA
编写简单的Python程序来判断文本的语种
2015/04/07 Python
Python使用xlrd模块操作Excel数据导入的方法
2015/05/26 Python
python安装numpy&安装matplotlib& scipy的教程
2017/11/02 Python
Python cookbook(数据结构与算法)实现查找两个字典相同点的方法
2018/02/18 Python
Numpy数组转置的两种实现方法
2018/04/17 Python
Python类的继承用法示例
2019/01/31 Python
使用OpCode绕过Python沙箱的方法详解
2019/09/03 Python
python抓取多种类型的页面方法实例
2019/11/20 Python
Node.js 和 Python之间该选择哪个?
2020/08/05 Python
python将字典内容写入json文件的实例代码
2020/08/12 Python
Linux管理员面试题 Linux admin interview questions
2014/11/01 面试题
新闻学专业应届生求职信
2013/11/08 职场文书
简历自我评价怎么写呢?
2014/01/06 职场文书
工作推荐信范文
2014/05/10 职场文书
政协工作总结2015
2015/05/20 职场文书
法制教育观后感
2015/06/17 职场文书
2019年二手房买卖合同范本
2019/10/14 职场文书
Python如何让字典保持有序排列
2022/04/29 Python