Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中捕捉详细异常信息的代码示例
Sep 18 Python
Python简单实现查找一个字符串中最长不重复子串的方法
Mar 26 Python
对python中使用requests模块参数编码的不同处理方法
May 18 Python
Python利用公共键如何对字典列表进行排序详解
May 19 Python
python实现抖音视频批量下载
Jun 20 Python
Python 查找list中的某个元素的所有的下标方法
Jun 27 Python
django做form表单的数据验证过程详解
Jul 26 Python
正则给header的冒号两边参数添加单引号(Python请求用)
Aug 09 Python
使用OpCode绕过Python沙箱的方法详解
Sep 03 Python
对pytorch的函数中的group参数的作用介绍
Feb 18 Python
python 邮件检测工具mmpi的使用
Jan 04 Python
numpy数据类型dtype转换实现
Apr 24 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
从Web查询数据库之PHP与MySQL篇
2009/09/25 PHP
Linux编译升级php的详细方法
2013/11/04 PHP
php可应用于面包屑导航的迭代寻找家谱树实现方法
2015/02/02 PHP
PHP机器学习库php-ml的简单测试和使用方法
2017/07/14 PHP
Javascript MD4
2006/12/20 Javascript
jquery中dom操作和事件的实例学习 仿yahoo邮箱登录框的提示效果
2011/11/30 Javascript
Jquery绑定事件(bind和live的区别介绍)
2013/08/23 Javascript
将json当数据库一样操作的javascript lib
2013/10/28 Javascript
JavaScript中的常见问题解决方法(乱码,IE缓存,代理)
2013/11/28 Javascript
jquery实现点击弹出层效果的简单实例
2014/03/03 Javascript
浅谈类似于(function(){}).call()的js语句
2015/03/30 Javascript
jquery插件uploadify实现带进度条的文件批量上传
2015/12/13 Javascript
基于gulp合并压缩Seajs模块的方式说明
2016/06/14 Javascript
JavaScript中cookie工具函数封装的示例代码
2016/10/11 Javascript
Javascript 实现简单计算器实例代码
2016/10/23 Javascript
Ajax跨域实现代码(后台jsp)
2017/01/21 Javascript
微信小程序实现带刻度尺滑块功能
2017/03/29 Javascript
基于canvas实现手写签名(vue)
2020/05/21 Javascript
openlayers4实现点动态扩散
2020/08/17 Javascript
一些常用的Python爬虫技巧汇总
2016/09/28 Python
详解pandas库pd.read_excel操作读取excel文件参数整理与实例
2019/02/17 Python
python读取Excel表格文件的方法
2019/09/02 Python
windows下python安装pip方法详解
2020/02/10 Python
将pytorch转成longtensor的简单方法
2020/02/18 Python
Python 如何对文件目录操作
2020/07/10 Python
python使用列表的最佳方案
2020/08/12 Python
Python 多进程原理及实现
2020/12/21 Python
生产车间实习自我鉴定
2013/09/23 职场文书
班主任经验交流会主持词
2014/04/01 职场文书
工业设计专业自荐书
2014/06/05 职场文书
群众路线查摆问题整改措施
2014/10/10 职场文书
公司人事管理制度
2015/08/05 职场文书
任命书格式模板
2015/09/22 职场文书
Python虚拟环境virtualenv是如何使用的
2021/06/20 Python
为什么代码规范要求SQL语句不要过多的join
2021/06/23 MySQL
Python中with上下文管理协议的作用及用法
2022/03/18 Python