Pytorch之保存读取模型实例


Posted in Python onDecember 30, 2019

pytorch保存数据

pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。

# 保存模型示例代码
print('===> Saving models...')
state = {
  'state': model.state_dict(),
  'epoch': epoch          # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
  os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')

保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。

pytorch读取数据

pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。

下方的代码和上方的保存代码可以搭配使用。

print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
  try:
    checkpoint = torch.load('./checkpoint/autoencoder.t7')
    model.load_state_dict(checkpoint['state'])    # 从字典中依次读取
    start_epoch = checkpoint['epoch']
    print('===> Load last checkpoint data')
  except FileNotFoundError:
    print('Can\'t found autoencoder.t7')
else:
  start_epoch = 0
  print('===> Start from scratch')

以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。

def vgg19(pretrained=False, **kwargs):
  """VGG 19-layer model (configuration "E")
 
  Args:
    pretrained (bool): If True, returns a model pre-trained on ImageNet
  """
  model = VGG(make_layers(cfg['E']), **kwargs)
  if pretrained:
    model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
  return model

假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:

model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))

也就是pytorch的读取函数进行读取即可。

以上这篇Pytorch之保存读取模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
编写Python脚本使得web页面上的代码高亮显示
Apr 24 Python
python中日志logging模块的性能及多进程详解
Jul 18 Python
python中类和实例如何绑定属性与方法示例详解
Aug 18 Python
Mac中Python 3环境下安装scrapy的方法教程
Oct 26 Python
详解Python Matplot中文显示完美解决方案
Mar 07 Python
详解python 利用echarts画地图(热力图)(世界地图,省市地图,区县地图)
Aug 06 Python
在django中实现页面倒数几秒后自动跳转的例子
Aug 16 Python
详解用Python调用百度地图正/逆地理编码API
Jul 02 Python
opencv 阈值分割的具体使用
Jul 08 Python
10款最佳Python开发工具推荐,每一款都是神器
Oct 15 Python
python两种获取剪贴板内容的方法
Nov 06 Python
Python+Appium自动化测试的实战
Jun 30 Python
Python爬虫解析网页的4种方式实例及原理解析
Dec 30 #Python
Python中如何将一个类方法变为多个方法
Dec 30 #Python
pytorch 实现打印模型的参数值
Dec 30 #Python
Python如何基于smtplib发不同格式的邮件
Dec 30 #Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 #Python
Python类反射机制使用实例解析
Dec 30 #Python
Python读取YAML文件过程详解
Dec 30 #Python
You might like
基于PHPExcel的常用方法总结
2013/06/13 PHP
jQuery+PHP+ajax实现微博加载更多内容列表功能
2014/06/27 PHP
php格式化日期实例分析
2014/11/12 PHP
thinkphp3.2.2前后台公用类架构问题分析
2014/11/25 PHP
PHP利用超级全局变量$_GET来接收表单数据的实例
2016/11/05 PHP
php封装的验证码类分享
2017/02/26 PHP
PHP中迭代器的简单实现及Yii框架中的迭代器实现方法示例
2020/04/26 PHP
javascript操作cookie的文章(设置,删除cookies)
2010/04/01 Javascript
javascript实现des解密加密全过程
2014/04/03 Javascript
javascript实现多栏闭合展开式广告位菜单效果实例
2015/08/05 Javascript
JS实现简单的二维矩阵乘积运算
2016/01/26 Javascript
基于js对象,操作属性、方法详解
2016/08/11 Javascript
前端弹出对话框 js实现ajax交互
2016/09/09 Javascript
浅谈DOM的操作以及性能优化问题-重绘重排
2017/01/08 Javascript
原生js实现对Ajax的封装(仿jquery)
2017/01/22 Javascript
jquery实时获取时间的简单实例
2017/01/26 Javascript
Vue监听数据对象变化源码
2017/03/09 Javascript
js实现鼠标拖拽多选功能示例
2017/08/01 Javascript
JavaScript数据结构之单链表和循环链表
2017/11/28 Javascript
判断jQuery是否加载完成,没完成继续判断的解决方法
2017/12/06 jQuery
axios发送post请求,提交图片类型表单数据方法
2018/03/16 Javascript
微信小程序自定义toast组件的方法详解【含动画】
2019/05/11 Javascript
Vue指令之 v-cloak、v-text、v-html实例详解
2019/08/08 Javascript
JS实现简单打字测试
2020/06/24 Javascript
wxpython中利用线程防止假死的实现方法
2014/08/11 Python
Python 自动刷博客浏览量实例代码
2017/06/14 Python
Python语言检测模块langid和langdetect的使用实例
2019/02/19 Python
Django认证系统实现的web页面实现代码
2019/08/12 Python
matplotlib.pyplot画图并导出保存的实例
2019/12/07 Python
django实现模型字段动态choice的操作
2020/04/01 Python
浅析Python 抽象工厂模式的优缺点
2020/07/13 Python
Bergfreunde丹麦:登山装备网上零售商
2017/02/26 全球购物
儿媳婚宴答谢词
2014/01/14 职场文书
社区健康教育工作方案
2014/06/03 职场文书
Java由浅入深通关抽象类与接口(下篇)
2022/04/26 Java/Android
js前端面试常见浏览器缓存强缓存及协商缓存实例
2022/06/21 Javascript