pytorch 状态字典:state_dict使用详解


Posted in Python onJanuary 17, 2020

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

备注:

1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.

模态字典(state_dict)的保存(model是一个网络结构类的对象)

1.1)仅保存学习到的参数,用以下命令

torch.save(model.state_dict(), PATH)

1.2)加载model.state_dict,用以下命令

model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名

2.1)保存整个model的状态,用以下命令

torch.save(model,PATH)

2.2)加载整个model的状态,用以下命令:

# Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项

如何仅加载某一层的训练的到的参数(某一层的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作对象是tensor.

疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:经测试,不可以.model.conv1 没有requires_grad属性.

全部测试代码:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上这篇pytorch 状态字典:state_dict使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的元类编程入门指引
Apr 15 Python
Python functools模块学习总结
May 09 Python
Python中运算符"=="和"is"的详解
Oct 08 Python
用python处理MS Word的实例讲解
May 08 Python
Python 普通最小二乘法(OLS)进行多项式拟合的方法
Dec 29 Python
python根据txt文本批量创建文件夹
Dec 08 Python
Python爬虫入门有哪些基础知识点
Jun 02 Python
keras 多任务多loss实例
Jun 22 Python
keras的ImageDataGenerator和flow()的用法说明
Jul 03 Python
Python安装Bs4的多种方法
Nov 28 Python
opencv用VS2013调试时用Image Watch插件查看图片
Jul 26 Python
python代码实现扫码关注公众号登录的实战
Nov 01 Python
Python标准库itertools的使用方法
Jan 17 #Python
Python实现投影法分割图像示例(二)
Jan 17 #Python
Python常用库大全及简要说明
Jan 17 #Python
Python Sphinx使用实例及问题解决
Jan 17 #Python
通过实例了解Python str()和repr()的区别
Jan 17 #Python
python无序链表删除重复项的方法
Jan 17 #Python
Python实现投影法分割图像示例(一)
Jan 17 #Python
You might like
PHP 中的一些经验积累
2006/10/09 PHP
PHP 页面跳转到另一个页面的多种方法方法总结
2009/07/07 PHP
php后台多用户权限组思路与实现程序代码分享
2012/02/13 PHP
全面解读PHP的Yii框架中的日志功能
2016/03/17 PHP
PHP中的多种加密技术及代码示例解析
2016/10/20 PHP
Zend Framework校验器Zend_Validate用法详解
2016/12/09 PHP
Laravel 默认邮箱登录改成用户名登录的实现方法
2019/08/12 PHP
JavaScript 工具库 Cloudgamer JavaScript Library v0.1 发布
2009/10/29 Javascript
innerText和textContent对比及使用介绍
2013/02/27 Javascript
JQuery操作iframe父页面与子页面的元素与方法(实例讲解)
2013/11/20 Javascript
JQuery的Ajax请求实现局部刷新的简单实例
2014/02/11 Javascript
超炫的jquery仿flash导航栏特效
2014/11/11 Javascript
javaScript基础语法介绍
2015/02/28 Javascript
JS将滑动门改为选项卡(需鼠标点击)的实现方法
2015/09/27 Javascript
jQuery实现文字超过1行、2行或规定的行数时自动加省略号的方法
2018/03/28 jQuery
通过vue-router懒加载解决首次加载时资源过多导致的速度缓慢问题
2018/04/08 Javascript
vue如何将v-for中的表格导出来
2018/05/07 Javascript
jQuery实现的简单手风琴效果示例
2018/08/29 jQuery
详解vuex持久化插件解决浏览器刷新数据消失问题
2019/04/15 Javascript
基于Node.js的大文件分片上传示例
2019/06/19 Javascript
详解vue beforeRouteEnter 异步获取数据给实例问题
2019/08/09 Javascript
python中字符串前面加r的作用
2015/06/04 Python
详解python进行mp3格式判断
2016/12/23 Python
Python实现的txt文件去重功能示例
2018/07/07 Python
详解Python学习之安装pandas
2019/04/16 Python
在PyCharm中控制台输出日志分层级分颜色显示的方法
2019/07/11 Python
Python二维码生成识别实例详解
2019/07/16 Python
python爬虫selenium和phantomJs使用方法解析
2019/08/08 Python
Python filter过滤器原理及实例应用
2020/08/18 Python
Python+Appium实现自动化清理微信僵尸好友的方法
2021/02/04 Python
CSS3实现的闪烁跳跃进度条示例(附源码)
2013/08/19 HTML / CSS
船餐厅和泰晤士河餐饮游轮:Bateaux London
2018/03/19 全球购物
俄罗斯厨房产品购物网站:COOK HOUSE
2021/03/15 全球购物
行政监察建议书
2014/05/19 职场文书
2015年员工试用期工作总结
2014/12/12 职场文书
2014年学生管理工作总结
2014/12/20 职场文书