从Pytorch模型pth文件中读取参数成numpy矩阵的操作


Posted in Python onMarch 04, 2021

目的:

把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。

Pytorch给了很方便的读取参数接口:

nn.Module.parameters()

直接看demo:

from torchvision.models.alexnet import alexnet 
model = alexnet(pretrained=True).eval().cuda()
parameters = model.parameters()
for p in parameters:
  numpy_para = p.detach().cpu().numpy()
  print(type(numpy_para))
  print(numpy_para.shape)

上面得到的numpy_para就是numpy参数了~

Note:

model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。

而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。

方便又好用,爆赞~

补充:pytorch训练好的.pth模型转换为.pt

将python训练好的.pth文件转为.pt

import torch
import torchvision
from unet import UNet
model = UNet(3, 2)#自己定义的网络模型
model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型
model.eval()#切换到eval()
example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
举例详解Python中的split()函数的使用方法
Apr 07 Python
Python网络编程之TCP套接字简单用法示例
Apr 09 Python
pytorch 数据集图片显示方法
Jul 26 Python
一篇文章弄懂Python中所有数组数据类型
Jun 23 Python
python 计算平均平方误差(MSE)的实例
Jun 29 Python
pandas实现to_sql将DataFrame保存到数据库中
Jul 03 Python
pandas的排序和排名的具体使用
Jul 31 Python
基于Python安装pyecharts所遇的问题及解决方法
Aug 12 Python
Python3批量移动指定文件到指定文件夹方法示例
Sep 02 Python
利用 PyCharm 实现本地代码和远端的实时同步功能
Mar 23 Python
Python如何基于Tesseract实现识别文字功能
Jun 05 Python
Pycharm2020最新激活码|永久激活(附最新激活码和插件的详细教程)
Sep 29 Python
python 如何用urllib与服务端交互(发送和接收数据)
Mar 04 #Python
python 求两个向量的顺时针夹角操作
Mar 04 #Python
python 制作磁力搜索工具
Mar 04 #Python
python抢购软件/插件/脚本附完整源码
Mar 04 #Python
Python 求向量的余弦值操作
Mar 04 #Python
django使用多个数据库的方法实例
Mar 04 #Python
Python使用paramiko连接远程服务器执行Shell命令的实现
Mar 04 #Python
You might like
一个php导出oracle库的php代码
2009/04/20 PHP
PHP及Zend Engine的线程安全模型分析
2011/11/10 PHP
PHP实现下载功能的代码
2012/09/29 PHP
php中current、next与reset函数用法实例
2014/11/17 PHP
PHP也能干大事 随机函数
2015/04/14 PHP
PHP设计模式之迭代器模式
2016/06/17 PHP
Bootstrap插件全集
2016/07/18 Javascript
jquery结合html实现中英文页面切换
2016/11/29 Javascript
详解vue 模版组件的三种用法
2017/07/21 Javascript
PHP自动加载autoload和命名空间的应用小结
2017/12/01 Javascript
使用Vue.js 和Chart.js制作绚丽多彩的图表
2019/06/15 Javascript
详解解决小程序中webview页面多层history返回问题
2019/08/20 Javascript
vue-resource post数据时碰到Django csrf问题的解决
2020/03/13 Javascript
vue添加自定义右键菜单的完整实例
2020/12/08 Vue.js
[01:08:56]DOTA2-DPC中国联赛 正赛 Magma vs LBZS BO3 第一场 2月7日
2021/03/11 DOTA
用Python计算三角函数之atan()方法的使用
2015/05/15 Python
Python3 加密(hashlib和hmac)模块的实现
2017/11/23 Python
Python计算一个给定时间点前一个月和后一个月第一天的方法
2018/05/29 Python
python微信公众号之关注公众号自动回复
2018/10/25 Python
python实现贪吃蛇游戏
2020/03/21 Python
Python使用scrapy爬取阳光热线问政平台过程解析
2019/08/14 Python
python爬虫工具例举说明
2020/11/30 Python
Python实例教程之检索输出月份日历表
2020/12/16 Python
NFL官方在线商店:NFLShop
2020/07/29 全球购物
*p++ 自增p 还是p所指向的变量
2016/07/16 面试题
解释一下钝化(Swap out)
2016/12/26 面试题
个人能力自我鉴赏
2014/01/25 职场文书
家教广告词
2014/03/19 职场文书
绿色学校实施方案
2014/03/31 职场文书
投标担保书范文
2014/04/02 职场文书
安全生产月标语
2014/10/07 职场文书
群众路线教育实践活动整改方案(个人版)
2014/10/25 职场文书
安全教育培训制度
2015/08/06 职场文书
小学班级口号大全
2015/12/25 职场文书
Python中递归以及递归遍历目录详解
2021/10/24 Python
MySql按时,天,周,月进行数据统计
2022/08/14 MySQL