从Pytorch模型pth文件中读取参数成numpy矩阵的操作


Posted in Python onMarch 04, 2021

目的:

把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。

Pytorch给了很方便的读取参数接口:

nn.Module.parameters()

直接看demo:

from torchvision.models.alexnet import alexnet 
model = alexnet(pretrained=True).eval().cuda()
parameters = model.parameters()
for p in parameters:
  numpy_para = p.detach().cpu().numpy()
  print(type(numpy_para))
  print(numpy_para.shape)

上面得到的numpy_para就是numpy参数了~

Note:

model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。

而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。

方便又好用,爆赞~

补充:pytorch训练好的.pth模型转换为.pt

将python训练好的.pth文件转为.pt

import torch
import torchvision
from unet import UNet
model = UNet(3, 2)#自己定义的网络模型
model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型
model.eval()#切换到eval()
example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
rhythmbox中文名乱码问题解决方法
Sep 06 Python
举例讲解Python中的迭代器、生成器与列表解析用法
Mar 20 Python
Python中super函数的用法
Nov 17 Python
浅谈Python中range和xrange的区别
Dec 20 Python
Python并发编程协程(Coroutine)之Gevent详解
Dec 27 Python
Pandas:Series和DataFrame删除指定轴上数据的方法
Nov 10 Python
Python实现简易过滤删除数字的方法小结
Jan 09 Python
Python OrderedDict的使用案例解析
Oct 25 Python
Python实现在Windows平台修改文件属性
Mar 05 Python
Python如何省略括号方法详解
Mar 21 Python
python判断是空的实例分享
Jul 06 Python
基于python实现监听Rabbitmq系统日志代码示例
Nov 28 Python
python 如何用urllib与服务端交互(发送和接收数据)
Mar 04 #Python
python 求两个向量的顺时针夹角操作
Mar 04 #Python
python 制作磁力搜索工具
Mar 04 #Python
python抢购软件/插件/脚本附完整源码
Mar 04 #Python
Python 求向量的余弦值操作
Mar 04 #Python
django使用多个数据库的方法实例
Mar 04 #Python
Python使用paramiko连接远程服务器执行Shell命令的实现
Mar 04 #Python
You might like
linux下 C语言对 php 扩展
2008/12/14 PHP
php返回json数据函数实例
2014/10/09 PHP
浅谈php自定义错误日志
2015/02/13 PHP
实现PHP+Mysql无限分类的方法汇总
2015/03/02 PHP
laravel框架与其他框架的详细对比
2019/10/23 PHP
Javascript 获取链接(url)参数的方法[正则与截取字符串]
2010/02/09 Javascript
基于Jquery 解决Ajax请求的页面 浏览器后退前进功能,页面刷新功能实效问题
2010/12/11 Javascript
jQuery简单实现日历的方法
2015/05/04 Javascript
javascript正则表达式中分组详解
2016/07/17 Javascript
Vuejs仿网易云音乐实现听歌及搜索功能
2017/03/30 Javascript
微信小程序访问node.js接口服务器搭建教程
2017/04/25 Javascript
win系统下nodejs环境安装配置
2017/05/04 NodeJs
激动人心的 Angular HttpClient的源码解析
2017/07/10 Javascript
JavaScript利用fetch实现异步请求的方法实例
2017/07/26 Javascript
select自定义小三角样式代码(实用总结)
2017/08/18 Javascript
如何让你的JS代码更好看易读
2017/12/01 Javascript
Javascript中JSON数据分组优化实践及JS操作JSON总结
2017/12/22 Javascript
基于Angular 8和Bootstrap 4实现动态主题切换的示例代码
2020/02/11 Javascript
[01:19:46]DOTA2-DPC中国联赛 正赛 SAG vs DLG BO3 第一场 2月28日
2021/03/11 DOTA
在Python中操作列表之List.pop()方法的使用
2015/05/21 Python
详解Django之admin组件的使用和源码剖析
2018/05/04 Python
python绘制热力图heatmap
2020/03/23 Python
对pandas中两种数据类型Series和DataFrame的区别详解
2018/11/12 Python
Python Gluon参数和模块命名操作教程
2019/12/18 Python
Django Admin后台添加数据库视图过程解析
2020/04/01 Python
python为什么要安装到c盘
2020/07/20 Python
CSS3支持IE6, 7, and 8的边框border属性
2012/12/28 HTML / CSS
html5组织文档结构_动力节点Java学院整理
2017/07/11 HTML / CSS
css 如何让背景图片拉伸填充避免重复显示
2013/07/11 HTML / CSS
HTML5梦幻之旅——炫丽的流星雨效果实现过程
2013/08/06 HTML / CSS
Python使用openpyxl复制整张sheet
2021/03/24 Python
实习自我鉴定模板
2013/09/28 职场文书
财务会计专业推荐信
2013/11/30 职场文书
村主任“四风”问题个人整改措施
2014/10/04 职场文书
面试中老生常谈的MySQL问答集锦夯实基础
2022/03/13 MySQL
详解Flutter和Dart取消Future的三种方法
2022/04/07 Java/Android