Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python网站验证码识别
Jan 25 Python
查看Django和flask版本的方法
May 14 Python
Python字符串、整数、和浮点型数相互转换实例
Aug 04 Python
在Pycharm中自动添加时间日期作者等信息的方法
Jan 16 Python
python实现扫描局域网指定网段ip的方法
Apr 16 Python
python openCV获取人脸部分并存储功能
Aug 28 Python
Python input函数使用实例解析
Nov 22 Python
python线程join方法原理解析
Feb 11 Python
在keras 中获取张量 tensor 的维度大小实例
Jun 10 Python
python + selenium 刷B站播放量的实例代码
Jun 12 Python
基于Python的一个自动录入表格的小程序
Aug 05 Python
Python 使用xlwt模块将多行多列数据循环写入excel文档的操作
Nov 10 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
星际争霸任务指南——神族
2020/03/04 星际争霸
php防止恶意刷新与刷票的方法
2014/11/21 PHP
PHP命名空间简单用法示例
2018/12/28 PHP
提高代码性能技巧谈—以创建千行表格为例
2006/07/01 Javascript
用js 让图片在 div或dl里 居中,底部对齐
2008/01/21 Javascript
整理8个很棒的 jQuery 倒计时插件和教程
2011/12/12 Javascript
jquery text()方法取标签中的文本
2014/07/25 Javascript
jQuery中[attribute=value]选择器用法实例
2014/12/31 Javascript
js常用系统函数用法实例分析
2015/01/12 Javascript
你所不了解的javascript操作DOM的细节知识点(一)
2015/06/17 Javascript
两种JS实现屏蔽鼠标右键的方法
2020/08/20 Javascript
5个最顶级jQuery图表类库插件【jquery插件库】
2016/05/05 Javascript
基于JS代码实现当鼠标悬停表格上显示这一格的全部内容
2016/06/12 Javascript
jQuery学习笔记之回调函数
2016/08/15 Javascript
JavaScript中浅讲ajax图文详解
2016/11/11 Javascript
jQ处理xml文件和xml字符串的方法(详解)
2016/11/22 Javascript
node文件批量重命名的方法示例
2017/10/23 Javascript
关于js陀螺仪的理解分析
2019/04/11 Javascript
java和js实现的洗牌小程序
2019/09/30 Javascript
JS实现多功能计算器
2020/10/28 Javascript
vue3弹出层V3Popup实例详解
2021/01/04 Vue.js
[01:42]辉夜杯战队访谈宣传片—FANTUAN
2015/12/25 DOTA
[01:04:02]DOTA2-DPC中国联赛 正赛 Elephant vs IG BO3 第二场 1月24日
2021/03/11 DOTA
python在linux中输出带颜色的文字的方法
2014/06/19 Python
python将.ppm格式图片转换成.jpg格式文件的方法
2018/10/27 Python
python多任务及返回值的处理方法
2019/01/22 Python
keras的load_model实现加载含有参数的自定义模型
2020/06/22 Python
小区消防演习方案
2014/02/21 职场文书
表彰大会主持词
2014/03/26 职场文书
教师个人教学反思
2016/02/23 职场文书
Python如何把不同类型数据的json序列化
2021/04/30 Python
postgres之jsonb属性的使用操作
2021/06/23 PostgreSQL
MySQL的安装与配置详细教程
2021/06/26 MySQL
jackson json序列化实现首字母大写,第二个字母需小写
2021/06/29 Java/Android
关于Python中进度条的六个实用技巧分享
2022/04/05 Python
Python语法学习之进程的创建与常用方法详解
2022/04/08 Python