Pytorch通过保存为ONNX模型转TensorRT5的实现


Posted in Python onMay 25, 2020

1 Pytorch以ONNX方式保存模型

def saveONNX(model, filepath):
  '''
  保存ONNX模型
  :param model: 神经网络模型
  :param filepath: 文件保存路径
  '''
  
  # 神经网络输入数据类型
  dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
  torch.onnx.export(model, dummy_input, filepath, verbose=True)

2 利用TensorRT5中ONNX解析器构建Engine

def ONNX_build_engine(onnx_file_path):
  '''
  通过加载onnx文件,构建engine
  :param onnx_file_path: onnx文件路径
  :return: engine
  '''
  # 打印日志
  G_LOGGER = trt.Logger(trt.Logger.WARNING)

  with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
   builder.max_batch_size = 100
   builder.max_workspace_size = 1 << 20

   print('Loading ONNX file from path {}...'.format(onnx_file_path))
   with open(onnx_file_path, 'rb') as model:
    print('Beginning ONNX file parsing')
    parser.parse(model.read())
   print('Completed parsing of ONNX file')

   print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
   engine = builder.build_cuda_engine(network)
   print("Completed creating Engine")

   # 保存计划文件
   # with open(engine_file_path, "wb") as f:
   #  f.write(engine.serialize())
   return engine

3 构建TensorRT运行引擎进行预测

def loadONNX2TensorRT(filepath):
  '''
  通过onnx文件,构建TensorRT运行引擎
  :param filepath: onnx文件路径
  '''
  # 计算开始时间
  Start = time()

  engine = self.ONNX_build_engine(filepath)

  # 读取测试集
  datas = DataLoaders()
  test_loader = datas.testDataLoader()
  img, target = next(iter(test_loader))
  img = img.numpy()
  target = target.numpy()

  img = img.ravel()

  context = engine.create_execution_context()
  output = np.empty((100, 10), dtype=np.float32)

  # 分配内存
  d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
  d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
  bindings = [int(d_input), int(d_output)]

  # pycuda操作缓冲区
  stream = cuda.Stream()
  # 将输入数据放入device
  cuda.memcpy_htod_async(d_input, img, stream)
  # 执行模型
  context.execute_async(100, bindings, stream.handle, None)
  # 将预测结果从从缓冲区取出
  cuda.memcpy_dtoh_async(output, d_output, stream)
  # 线程同步
  stream.synchronize()

  print("Test Case: " + str(target))
  print("Prediction: " + str(np.argmax(output, axis=1)))
  print("tensorrt time:", time() - Start)

  del context
  del engine

补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT

近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。

后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。

是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。

以上这篇Pytorch通过保存为ONNX模型转TensorRT5的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之从if开始语句的征程
Sep 14 Python
python中ConfigParse模块的用法
Sep 29 Python
python中__slots__用法实例
Jun 04 Python
Python随手笔记第一篇(2)之初识列表和元组
Jan 23 Python
浅谈Python中range和xrange的区别
Dec 20 Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 Python
python绘制立方体的方法
Jul 02 Python
浅析python继承与多重继承
Sep 13 Python
pandas的qcut()方法详解
Jul 06 Python
django-filter和普通查询的例子
Aug 12 Python
Python3 合并二叉树的实现
Sep 30 Python
Python安装OpenCV的示例代码
Mar 05 Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
You might like
PHP文本操作类
2006/11/25 PHP
用PHP的ob_start();控制您的浏览器cache!
2006/11/25 PHP
phpExcel导出大量数据出现内存溢出错误的解决方法
2013/02/28 PHP
php生成扇形比例图实例
2013/11/06 PHP
php页码形式分页函数支持静态化地址及ajax分页
2014/03/28 PHP
destoon调用自定义模板及样式的公告栏
2014/06/21 PHP
php使用MySQL保存session会话的方法
2015/06/26 PHP
php断点续传之文件分割合并详解
2016/12/13 PHP
Web Inspector:关于在 Sublime Text 中调试Js的介绍
2013/04/18 Javascript
js中widow.open()方法使用详解
2013/07/30 Javascript
js调试系列 初识控制台
2014/06/18 Javascript
javascript获取元素偏移量的方法有哪些
2014/06/24 Javascript
angularJS 入门基础
2015/02/09 Javascript
jQuery使用hide方法隐藏页面上指定元素的方法
2015/03/30 Javascript
vue 中引用gojs绘制E-R图的方法示例
2018/08/24 Javascript
layer.js open 隐藏滚动条的例子
2019/09/05 Javascript
Node.js API详解之 Error模块用法实例分析
2020/05/14 Javascript
JS如何寻找数组中心索引过程解析
2020/06/01 Javascript
Python 3.x 安装opencv+opencv_contrib的操作方法
2018/04/02 Python
Python3.6基于正则实现的计算器示例【无优化简单注释版】
2018/06/14 Python
python tkinter窗口最大化的实现
2019/07/15 Python
python同步windows和linux文件
2019/08/29 Python
TensorFlow设置日志级别的几种方式小结
2020/02/04 Python
Python使用Paramiko控制liunx第三方库
2020/05/20 Python
Django中F函数的使用示例代码详解
2020/07/06 Python
Python 代码调试技巧示例代码
2020/08/11 Python
解决TensorFlow训练模型及保存数量限制的问题
2021/03/03 Python
HTML5 Blob 实现文件下载功能的示例代码
2019/11/29 HTML / CSS
远程Wi-Fi宠物监控相机:Petcube
2017/04/26 全球购物
长曲棍球装备:Lacrosse Monkey
2020/12/02 全球购物
安全责任书怎么写
2014/07/28 职场文书
火锅店的活动方案
2014/08/15 职场文书
检讨书大全
2015/01/27 职场文书
企业安全生产规章制度
2015/08/06 职场文书
如何撰写出一份完美的商业计划书?
2019/07/12 职场文书
pandas 实现将NaN转换为None
2021/05/14 Python