Pytorch转onnx、torchscript方式


Posted in Python onMay 25, 2020

前言

本文将介绍如何使用ONNX将PyTorch中训练好的模型(.pt、.pth)型转换为ONNX格式,然后将其加载到Caffe2中。需要安装好onnx和Caffe2。

PyTorch及ONNX环境准备

为了正常运行ONNX,我们需要安装最新的Pytorch,你可以选择源码安装:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
mkdir build && cd build
sudo cmake .. -DPYTHON_INCLUDE_DIR=/usr/include/python3.6 -DUSE_MPI=OFF
make install
export PYTHONPATH=$PYTHONPATH:/opt/pytorch/build

其中 "/opt/pytorch/build"是前面build pytorch的目。

or conda安装

conda install pytorch torchvision -c pytorch

安装ONNX的库

conda install -c conda-forge onnx

onnx-caffe2 安装

pip3 install onnx-caffe2

Pytorch模型转onnx

在PyTorch中导出模型通过跟踪工作。要导出模型,请调用torch.onnx.export()函数。这将执行模型,记录运算符用于计算输出的轨迹。因为_export运行模型,我们需要提供输入张量x。

这个张量的值并不重要; 它可以是图像或随机张量,只要它是正确的大小。更多详细信息,请查看torch.onnx documentation文档。

# 输入模型
example = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# 导出模型
torch_out = torch_out = torch.onnx.export(model, # model being run
    example, # model input (or a tuple for multiple inputs)
    "peleeNet.onnx",
 verbose=False, # store the trained parameter weights inside the model file
 training=False,
 do_constant_folding=True,
 input_names=['input'],
 output_names=['output'])

其中torch_out是执行模型后的输出,通常以忽略此输出。转换得到onnx后可以使用OpenCV的 cv::dnn::readNetFromONNX or cv::dnn::readNet进行模型加载推理了。

还可以进一步将onnx模型转换为ncnn进而部署到移动端。这就需要ncnn的onnx2ncnn工具了.

编译ncnn源码,生成 onnx2ncnn。

其中onnx转换模型时有一些冗余,可以使用用工具简化一些onnx模型。

pip3 install onnx-simplifier

简化onnx模型

python3 -m onnxsim pnet.onnx pnet-sim.onnx

转换成ncnn

onnx2ncnn pnet-sim.onnx pnet.param pnet.bin

ncnn 加载模型做推理

Pytorch模型转torch script

pytorch 加入libtorch前端处理,集体步骤为:

Pytorch转onnx、torchscript方式

以mtcnn pnet为例

# convert pytorch model to torch script
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 12, 12).to(device)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(pnet, example)
# Save traced model
traced_script_module.save("pnet_model_final.pth")

C++调用如下所示:

#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) 
{
 if (argc != 2) 
 {
 std::cerr << "usage: example-app <path-to-exported-script-module>\n";
 return -1;
 }

 // Deserialize the ScriptModule from a file using torch::jit::load().
 std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);

 assert(module != nullptr);
 std::cout << "ok\n";
}
Python 相关文章推荐
Python中的并发编程实例
Jul 07 Python
对Python新手编程过程中如何规避一些常见问题的建议
Apr 01 Python
Python正则表达式分组概念与用法详解
Jun 24 Python
python reduce 函数使用详解
Dec 05 Python
python操作oracle的完整教程分享
Jan 30 Python
python实现控制台打印的方法
Jan 12 Python
Python实现微信翻译机器人的方法
Aug 13 Python
python中的RSA加密与解密实例解析
Nov 18 Python
python的slice notation的特殊用法详解
Dec 27 Python
详解pandas中iloc, loc和ix的区别和联系
Mar 09 Python
Pytorch学习之torch用法----比较操作(Comparison Ops)
Jun 28 Python
pandas针对excel处理的实现
Jan 15 Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
基于pandas向csv添加新的行和列
May 25 #Python
Python如何把十进制数转换成ip地址
May 25 #Python
tensorflow模型转ncnn的操作方式
May 25 #Python
MxNet预训练模型到Pytorch模型的转换方式
May 25 #Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
You might like
如何在PHP中使用Oracle数据库(6)
2006/10/09 PHP
浅析get与post的一些特殊情况
2014/07/28 PHP
php通过exif_read_data函数获取图片的exif信息
2015/05/21 PHP
php基础教程
2015/08/26 PHP
php原生导出excel文件的两种方法(推荐)
2016/11/19 PHP
PHP 计算两个时间段之间交集的天数示例
2019/10/24 PHP
gearman管理工具GearmanManager的安装与php使用方法示例
2020/02/27 PHP
IE8 下的Js错误HTML Parsing Error...
2009/08/14 Javascript
javascript日期对象格式化为字符串的实现方法
2014/01/14 Javascript
jQuery CSS()方法改变现有的CSS样式
2014/08/20 Javascript
jQuery与getJson结合的用法实例
2015/08/07 Javascript
javascript精确统计网站访问量实例代码
2015/12/19 Javascript
图解js图片轮播效果
2015/12/20 Javascript
基于javascript实现简单的抽奖系统
2020/04/15 Javascript
js实现div在页面拖动效果
2016/05/04 Javascript
jQuery属性选择器用法示例
2016/09/09 Javascript
jQuery实现动态生成表格并为行绑定单击变色动作的方法
2017/04/17 jQuery
vue 2.x 中axios 封装的get 和post方法
2018/02/28 Javascript
实例详解vue.js浅度监听和深度监听及watch用法
2018/08/16 Javascript
详解webpack编译速度提升之DllPlugin
2019/02/05 Javascript
vue2.0+vue-router构建一个简单的列表页的示例代码
2019/02/13 Javascript
JS中的算法与数据结构之二叉查找树(Binary Sort Tree)实例详解
2019/08/16 Javascript
python PIL模块与随机生成中文验证码
2016/02/27 Python
Python科学计算之NumPy入门教程
2017/01/15 Python
python实现requests发送/上传多个文件的示例
2018/06/04 Python
使用pyecharts生成Echarts网页的实例
2019/08/12 Python
matplotlib设置颜色、标记、线条,让你的图像更加丰富(推荐)
2020/09/25 Python
阿迪达斯印尼官方网站:adidas印尼
2020/02/10 全球购物
计算机专业个人简短的自我评价
2013/10/23 职场文书
项目申报专员岗位职责
2014/07/09 职场文书
初中重阳节活动总结
2015/05/05 职场文书
2016年大学迎新工作总结
2015/10/14 职场文书
python tkinter模块的简单使用
2021/04/07 Python
Tensorflow与RNN、双向LSTM等的踩坑记录及解决
2021/05/31 Python
微信小程序基础教程之echart的使用
2021/06/01 Javascript
redis客户端实现高可用读写分离的方式详解
2021/07/04 Redis