pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用setup.py安装python包和卸载python包的方法
Nov 27 Python
Python查找函数f(x)=0根的解决方法
May 07 Python
Pycharm学习教程(5) Python快捷键相关设置
May 03 Python
对Python进行数据分析_关于Package的安装问题
May 22 Python
Python实现可自定义大小的截屏功能
Jan 20 Python
Python多线程扫描端口代码示例
Feb 09 Python
pandas按若干个列的组合条件筛选数据的方法
Apr 11 Python
Python一个简单的通信程序(客户端 服务器)
Mar 06 Python
解决python tkinter界面卡死的问题
Jul 17 Python
推荐8款常用的Python GUI图形界面开发框架
Feb 23 Python
Python venv虚拟环境配置过程解析
Jul 08 Python
Python实例方法、类方法、静态方法区别详解
Sep 05 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
PHP 加密解密内部算法
2010/04/22 PHP
PHP将HTML转换成文本的实现代码
2015/01/21 PHP
PHP函数用法详解【初始化、嵌套、内置函数等】
2020/06/02 PHP
模拟jQuery ajax服务器端与客户端通信的代码
2011/03/28 Javascript
jQuery EasyUI API 中文文档 可调整尺寸
2011/09/29 Javascript
jquery win 7透明弹出层效果的简单代码
2013/08/06 Javascript
js键盘上下左右键怎么触发function(实例讲解)
2013/12/14 Javascript
javascript 获取函数形参个数
2014/07/31 Javascript
jquery地址栏链接与a标签链接匹配之特效代码总结
2015/08/24 Javascript
js实现的星星评分功能函数
2015/12/09 Javascript
JSON 的正确用法探讨:Pyhong、MongoDB、JavaScript与Ajax
2016/05/15 Javascript
JavaScript基于自定义函数判断变量类型的实现方法
2016/11/23 Javascript
解决vue接口数据赋值给data没有反应的问题
2018/08/27 Javascript
JavaScript实现的鼠标跟随特效示例【2则实例】
2018/12/22 Javascript
微信公众平台 客服接口发消息的实现代码(Java接口开发)
2019/04/17 Javascript
layui table表格数据的新增,修改,删除,查询,双击获取行数据方式
2019/11/14 Javascript
JavaScript实现轮播图效果
2020/10/30 Javascript
[31:33]2014 DOTA2国际邀请赛中国区预选赛 TongFu VS DT 第一场
2014/05/23 DOTA
使用Python实现简单的服务器功能
2017/08/25 Python
python使用tensorflow保存、加载和使用模型的方法
2018/01/31 Python
详解Python 定时框架 Apscheduler原理及安装过程
2019/06/14 Python
python使用装饰器作日志处理的方法
2019/07/11 Python
python try except返回异常的信息字符串代码实例
2019/08/15 Python
将python依赖包打包成window下可执行文件bat方式
2019/12/26 Python
python 爬取马蜂窝景点翻页文字评论的实现
2020/01/20 Python
关于tensorflow softmax函数用法解析
2020/06/30 Python
Python实现一个简单的递归下降分析器
2020/08/01 Python
Django执行源生mysql语句实现过程解析
2020/11/12 Python
洛杉矶生活休闲而精致的基础品牌:Mika Jaymes
2018/01/07 全球购物
俄罗斯外国汽车和国产汽车配件网上商店:Движком
2020/04/19 全球购物
学生党员思想汇报
2013/12/28 职场文书
财务总监管理岗位职责
2014/03/08 职场文书
《梅花魂》教学反思
2014/04/30 职场文书
五水共治一句话承诺
2014/05/30 职场文书
商业项目策划方案
2014/06/05 职场文书
2014年秋季开学典礼主持词
2014/08/02 职场文书