Pytorch转tflite方式


Posted in Python onMay 25, 2020

目标是想把在服务器上用pytorch训练好的模型转换为可以在移动端运行的tflite模型。

最直接的思路是想把pytorch模型转换为tensorflow的模型,然后转换为tflite。但是这个转换目前没有发现比较靠谱的方法。

经过调研发现最新的tflite已经支持直接从keras模型的转换,所以可以采用keras作为中间转换的桥梁,这样就能充分利用keras高层API的便利性。

转换的基本思想就是用pytorch中的各层网络的权重取出来后直接赋值给keras网络中的对应layer层的权重。

转换为Keras模型后,再通过tf.contrib.lite.TocoConverter把模型直接转为tflite.

下面是一个例子,假设转换的是一个两层的CNN网络。

import tensorflow as tf
from tensorflow import keras
import numpy as np

import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable

class PytorchNet(nn.Module):
 def __init__(self):
 super(PytorchNet, self).__init__()
 conv1 = nn.Sequential(
  nn.Conv2d(3, 32, 3, 2),
  nn.BatchNorm2d(32),
  nn.ReLU(inplace=True),
  nn.MaxPool2d(2, 2))
 conv2 = nn.Sequential(
  nn.Conv2d(32, 64, 3, 1, groups=1),
  nn.BatchNorm2d(64),
  nn.ReLU(inplace=True),
  nn.MaxPool2d(2, 2))
 self.feature = nn.Sequential(conv1, conv2)
 self.init_weights()

 def forward(self, x):
 return self.feature(x)

 def init_weights(self):
 for m in self.modules():
  if isinstance(m, nn.Conv2d):
  nn.init.kaiming_normal_(
   m.weight.data, mode='fan_out', nonlinearity='relu')
  if m.bias is not None:
   m.bias.data.zero_()
  if isinstance(m, nn.BatchNorm2d):
  m.weight.data.fill_(1)
  m.bias.data.zero_()

def KerasNet(input_shape=(224, 224, 3)):
 image_input = keras.layers.Input(shape=input_shape)
 # conv1
 network = keras.layers.Conv2D(
 32, (3, 3), strides=(2, 2), padding="valid")(image_input)
 network = keras.layers.BatchNormalization(
 trainable=False, fused=False)(network)
 network = keras.layers.Activation("relu")(network)
 network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

 # conv2
 network = keras.layers.Conv2D(
 64, (3, 3), strides=(1, 1), padding="valid")(network)
 network = keras.layers.BatchNormalization(
 trainable=False, fused=True)(network)
 network = keras.layers.Activation("relu")(network)
 network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network)

 model = keras.Model(inputs=image_input, outputs=network)

 return model

class PytorchToKeras(object):
 def __init__(self, pModel, kModel):
 super(PytorchToKeras, self)
 self.__source_layers = []
 self.__target_layers = []
 self.pModel = pModel
 self.kModel = kModel
 tf.keras.backend.set_learning_phase(0)

 def __retrieve_k_layers(self):
 for i, layer in enumerate(self.kModel.layers):
  if len(layer.weights) > 0:
  self.__target_layers.append(i)

 def __retrieve_p_layers(self, input_size):

 input = torch.randn(input_size)
 input = Variable(input.unsqueeze(0))
 hooks = []

 def add_hooks(module):

  def hook(module, input, output):
  if hasattr(module, "weight"):
   # print(module)
   self.__source_layers.append(module)

  if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel:
  hooks.append(module.register_forward_hook(hook))

 self.pModel.apply(add_hooks)

 self.pModel(input)
 for hook in hooks:
  hook.remove()

 def convert(self, input_size):
 self.__retrieve_k_layers()
 self.__retrieve_p_layers(input_size)

 for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)):
  print(source_layer)
  weight_size = len(source_layer.weight.data.size())
  transpose_dims = []
  for i in range(weight_size):
  transpose_dims.append(weight_size - i - 1)
  if isinstance(source_layer, nn.Conv2d):
  transpose_dims = [2,3,1,0]
  self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(
  ).transpose(transpose_dims), source_layer.bias.data.numpy()])
  elif isinstance(source_layer, nn.BatchNorm2d):
  self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(),
        source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()])
 def save_model(self, output_file):
 self.kModel.save(output_file)

 def save_weights(self, output_file):
 self.kModel.save_weights(output_file, save_format='h5')

pytorch_model = PytorchNet()
keras_model = KerasNet(input_shape=(224, 224, 3))

torch.save(pytorch_model, 'test.pth')

#Load the pretrained model
pytorch_model = torch.load('test.pth')

# #Time to transfer weights
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))

# #Save the converted keras model for later use
# converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")

# convert keras model to tflite model
converter = tf.contrib.lite.TocoConverter.from_keras_model_file(
 "keras_model.h5")
tflite_model = converter.convert()
open("convert_model.tflite", "wb").write(tflite_model)

补充知识:tensorflow模型转换成tensorflow lite模型

1.把graph和网络模型打包在一个文件中

bazel build tensorflow/python/tools:freeze_graph && \
 bazel-bin/tensorflow/python/tools/freeze_graph \
 --input_graph=eval_graph_def.pb \
 --input_checkpoint=checkpoint \
 --output_graph=frozen_eval_graph.pb \
 --output_node_names=outputs

For example:

bazel-bin/tensorflow/python/tools/freeze_graph \ 
 --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \
 --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
 --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
 --output_node_names=MobilenetV1/Predictions/Reshape_1

2.把第一步中生成的tensorflow pb模型转换为tf lite模型

转换前需要先编译转换工具

bazel build tensorflow/contrib/lite/toco:toco

转换分两种,一种的转换为float的tf lite,另一种可以转换为对模型进行unit8的量化版本的模型。两种方式如下:

非量化的转换:

./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ 官网给的这个路径不对       
./bazel-bin/tensorflow/contrib/lite/toco/toco \         
 —input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \  
 —output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \  
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \       
 --inference_type=FLOAT \           
 --input_shape="1,224, 224,3" \           
 --input_array=input \            
 --output_array=MobilenetV1/Predictions/Reshape_1

量化方式的转换(注意,只有量化训练的模型才能进行量化的tf_lite转换):

./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \
./bazel-bin/tensorflow/contrib/lite/toco/toco \
 --input_file=frozen_eval_graph.pb \
 --output_file=tflite_model.tflite \
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
 --inference_type=QUANTIZED_UINT8 \
 --input_shape="1,224, 224,3" \
 --input_array=input \
 --output_array=outputs \
 --std_value=127.5 --mean_value=127.5

以上这篇Pytorch转tflite方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python深入学习之闭包
Aug 31 Python
Python使用sftp实现上传和下载功能(实例代码)
Mar 14 Python
python 定义给定初值或长度的list方法
Jun 23 Python
对python读取CT医学图像的实例详解
Jan 24 Python
利用PyCharm Profile分析异步爬虫效率详解
May 08 Python
python反编译学习之字节码详解
May 19 Python
使用tensorflow DataSet实现高效加载变长文本输入
Jan 20 Python
浅谈keras的深度模型训练过程及结果记录方式
Jan 24 Python
Python3查找列表中重复元素的个数的3种方法详解
Feb 13 Python
python 删除excel表格重复行,数据预处理操作
Jul 06 Python
python神经网络编程之手写数字识别
May 08 Python
基于Python绘制子图及子图刻度的变换等的问题
May 23 Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 #Python
keras .h5转移动端的.tflite文件实现方式
May 25 #Python
Python虚拟环境venv用法详解
May 25 #Python
将keras的h5模型转换为tensorflow的pb模型操作
May 25 #Python
You might like
多文件上传的例子
2006/10/09 PHP
基于mysql的论坛(5)
2006/10/09 PHP
PHP对象递归引用造成内存泄漏分析
2014/08/28 PHP
PHP+Ajax+JS实现多图上传
2016/05/07 PHP
javascript 日期时间函数(经典+完善+实用)
2009/05/27 Javascript
javascript offsetX与layerX区别
2010/03/12 Javascript
Jquery中find与each方法用法实例
2015/02/04 Javascript
Javascript核心读书有感之语句
2015/02/11 Javascript
AngularJS 实现弹性盒子布局的方法
2016/08/30 Javascript
Bootstrap Modal遮罩弹出层代码分享
2016/11/21 Javascript
JS正则表达式修饰符中multiline(/m)用法分析
2016/12/27 Javascript
解决bootstrap中使用modal加载kindeditor时弹出层文本框不能输入的问题
2017/06/05 Javascript
vue中element组件样式修改无效的解决方法
2018/02/03 Javascript
详解如何解决Vue和vue-template-compiler版本之间的问题
2018/09/17 Javascript
angular使用md5,CryptoJS des加密的方法
2019/06/03 Javascript
vue基本使用--refs获取组件或元素的实例
2019/11/07 Javascript
如何基于jQuery实现五角星评分
2020/09/02 jQuery
微信小程序实现列表左右滑动
2020/11/19 Javascript
python基础教程之常用运算符
2014/08/29 Python
Python文件与文件夹常见基本操作总结
2016/09/19 Python
Python定义二叉树及4种遍历方法实例详解
2018/07/05 Python
使用Python自动化破解自定义字体混淆信息的方法实例
2019/02/13 Python
selenium+python实现自动登陆QQ邮箱并发送邮件功能
2019/12/13 Python
Python3监控windows,linux系统的CPU、硬盘、内存使用率和各个端口的开启情况详细代码实例
2020/03/18 Python
iPython pylab模式启动方式
2020/04/24 Python
css3实现蒙版弹幕功能
2019/06/18 HTML / CSS
HTML5之SVG 2D入门6—视窗坐标系与用户坐标系及变换概述
2013/01/30 HTML / CSS
北美Newegg打造的全球尖货海购平台:tt海购
2018/09/28 全球购物
Dogeared官网:在美国手工制作的珠宝
2019/08/24 全球购物
应用艺术毕业生的自我评价
2013/12/04 职场文书
什么样的创业计划书可行性高?
2014/02/01 职场文书
鲁迅故里导游词
2015/02/05 职场文书
工程部岗位职责范本
2015/04/11 职场文书
公司文体活动总结
2015/05/07 职场文书
投诉信回复范文
2015/07/03 职场文书
如何利用Python实现n*n螺旋矩阵
2022/01/18 Python