pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本实现统计日志文件中的ip访问次数代码分享
Aug 06 Python
Python解析json文件相关知识学习
Mar 01 Python
深入理解NumPy简明教程---数组3(组合)
Dec 17 Python
利用python 更新ssh 远程代码 操作远程服务器的实现代码
Feb 08 Python
Flask框架配置与调试操作示例
Jul 23 Python
Python实现常见的回文字符串算法
Nov 14 Python
Django实现一对多表模型的跨表查询方法
Dec 18 Python
一篇文章弄懂Python中所有数组数据类型
Jun 23 Python
Python实现个人微信号自动监控告警的示例
Jul 03 Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 Python
python绘制雷达图实例讲解
Jan 03 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
Laravel模型间关系设置分表的方法示例
2018/04/21 PHP
浅谈PHP5.6 与 PHP7.0 区别
2019/10/09 PHP
JQuery Easyui Tree的oncheck事件实现代码
2010/05/28 Javascript
JavaScript可否多线程? 深入理解JavaScript定时机制
2012/05/23 Javascript
js屏蔽鼠标键盘(右键/Ctrl+N/Shift+F10/F11/F5刷新/退格键)
2013/01/24 Javascript
jquery网页元素拖拽插件效果及实现
2013/08/05 Javascript
node.js正则表达式获取网页中所有链接的代码实例
2014/06/03 Javascript
Jquery代码实现图片轮播效果(一)
2015/08/12 Javascript
JQuery EasyUI Layout 在from布局自适应窗口大小的实现方法
2016/05/28 Javascript
微信小程序使用第三方库Immutable.js实例详解
2016/09/27 Javascript
利用jsonp与代理服务器方案解决跨域问题
2017/09/14 Javascript
fullpage.js最后一屏滚动方式
2018/02/06 Javascript
Vue中使用vue-i18插件实现多语言切换功能
2018/04/25 Javascript
详解javascript 变量提升(Hoisting)
2019/03/12 Javascript
vue实现行列转换的一种方法
2019/08/06 Javascript
js实现跳一跳小游戏
2020/07/31 Javascript
[01:10:03]OG vs EG 2018国际邀请赛淘汰赛BO3 第三场 8.23
2018/08/24 DOTA
在Python中处理字符串之ljust()方法的使用简介
2015/05/19 Python
Python实现判断字符串中包含某个字符的判断函数示例
2018/01/08 Python
python opencv旋转图像(保持图像不被裁减)
2018/07/26 Python
python文件拆分与重组实例
2018/12/10 Python
Python实现的在特定目录下导入模块功能分析
2019/02/11 Python
利用python脚本如何简化jar操作命令
2019/02/24 Python
python图形工具turtle绘制国际象棋棋盘
2019/05/23 Python
Python3内置模块之json编解码方法小结【推荐】
2020/12/09 Python
Django框架创建mysql连接与使用示例
2019/07/29 Python
大学毕业感言50字
2014/02/07 职场文书
优秀管理者获奖感言
2014/02/17 职场文书
挂职自我鉴定
2014/02/26 职场文书
财务人员的自我评价范文
2014/03/03 职场文书
三分钟自我介绍演讲稿
2014/08/21 职场文书
俞敏洪一分钟演讲稿
2014/08/26 职场文书
餐饮服务食品安全承诺书
2015/04/29 职场文书
同意转租证明
2015/06/24 职场文书
vue中的可拖拽宽度div的实现示例
2022/04/08 Vue.js
SQL Server 中的事务介绍
2022/05/20 SQL Server