tensorflow pb to tflite 精度下降详解


Posted in Python onMay 25, 2020

之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。

思路主要是想使用tflite部署到安卓端,但是在使用tflite的时候发现模型的精度大幅度下降,已经不能支持业务需求了,最后就把OCR模型调用写在服务端了,但是精度下降的原因目前也没有找到,现在这里记录一下。

工作思路:

1.训练图像分类模型;2.模型固化成pb;3.由pb转成tflite文件;

但是使用python 的tf interpreter 调用tflite文件就已经出现精度下降的问题,android端部署也是一样。

1.网络结构

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import tensorflow as tf
slim = tf.contrib.slim
 
def ttnet(images, num_classes=10, is_training=False,
   dropout_keep_prob=0.5,
   prediction_fn=slim.softmax,
   scope='TtNet'):
 end_points = {}
 
 with tf.variable_scope(scope, 'TtNet', [images, num_classes]):
 net = slim.conv2d(images, 32, [3, 3], scope='conv1')
 # net = slim.conv2d(images, 64, [3, 3], scope='conv1_2')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
 net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn1')
 # net = slim.conv2d(net, 128, [3, 3], scope='conv2_1')
 net = slim.conv2d(net, 64, [3, 3], scope='conv2')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
 net = slim.conv2d(net, 128, [3, 3], scope='conv3')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
 net = slim.conv2d(net, 256, [3, 3], scope='conv4')
 net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
 net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn2')
 # net = slim.conv2d(net, 512, [3, 3], scope='conv5')
 # net = slim.max_pool2d(net, [2, 2], 2, scope='pool5')
 net = slim.flatten(net)
 end_points['Flatten'] = net
 
 # net = slim.fully_connected(net, 1024, scope='fc3')
 net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
      scope='dropout3')
 logits = slim.fully_connected(net, num_classes, activation_fn=None,
         scope='fc4') 
 end_points['Logits'] = logits
 end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
 
 return logits, end_points
ttnet.default_image_size = 28
 
def ttnet_arg_scope(weight_decay=0.0):
 with slim.arg_scope(
  [slim.conv2d, slim.fully_connected],
  weights_regularizer=slim.l2_regularizer(weight_decay),
  weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
  activation_fn=tf.nn.relu) as sc:
 return sc

基于slim,由于是一个比较简单的分类问题,网络结构也很简单,几个卷积加池化。

测试效果是很棒的。真实样本测试集能达到99%+的准确率。

2.模型固化,生成pb文件

#coding:utf-8
 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
import cv2
import os
import numpy as np
from datasets import dataset_factory
from preprocessing import preprocessing_factory
from tensorflow.python.platform import gfile
slim = tf.contrib.slim
#todo
#support arbitray image size and num_class
 
tf.app.flags.DEFINE_string(
 'checkpoint_path', '/tmp/tfmodel/',
 'The directory where the model was written to or an absolute path to a '
 'checkpoint file.')
 
tf.app.flags.DEFINE_string(
 'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
 'preprocessing_name', None, 'The name of the preprocessing to use. If left '
 'as `None`, then the model_name flag is used.')
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer(
 'eval_image_size', None, 'Eval image size')
tf.app.flags.DEFINE_integer(
 'eval_image_height', None, 'Eval image height')
tf.app.flags.DEFINE_integer(
 'eval_image_width', None, 'Eval image width')
tf.app.flags.DEFINE_string(
 'export_path', './ttnet_1.0_37_32.pb', 'the export path of the pd file')
FLAGS = tf.app.flags.FLAGS
NUM_CLASSES = 37
 
def main(_):
 network_fn = nets_factory.get_network_fn(
  FLAGS.model_name,
  num_classes=NUM_CLASSES,
  is_training=False)
 # pre_image = tf.placeholder(tf.float32, [None, None, 3], name='input_data')
 # preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
 # image_preprocessing_fn = preprocessing_factory.get_preprocessing(
 #  preprocessing_name,
 #  is_training=False)
 # image = image_preprocessing_fn(pre_image, FLAGS.eval_image_height, FLAGS.eval_image_width)
 # images2 = tf.expand_dims(image, 0)
 images2 = tf.placeholder(tf.float32, (None,32, 32, 3),name='input_data')
 logits, endpoints = network_fn(images2)
 with tf.Session() as sess:
 output = tf.identity(endpoints['Predictions'],name="output_data")
 with gfile.GFile(FLAGS.export_path, 'wb') as f:
  f.write(sess.graph_def.SerializeToString())
 
if __name__ == '__main__':
 tf.app.run()

3.生成tflite文件

import tensorflow as tf
 
graph_def_file = "/datastore1/Colonist_Lord/Colonist_Lord/workspace/models/model_files/passport_model_with_tflite/ocr_frozen.pb"
input_arrays = ["input_data"]
output_arrays = ["output_data"]
 
converter = tf.lite.TFLiteConverter.from_frozen_graph(
 graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

使用pb文件进行测试,效果正常;使用tflite文件进行测试,精度下降严重。下面附上pb与tflite测试代码。

pb测试代码

with tf.gfile.GFile(graph_filename, "rb") as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 
with tf.Graph().as_default() as graph:
 tf.import_graph_def(graph_def)
 input_node = graph.get_tensor_by_name('import/input_data:0')
 output_node = graph.get_tensor_by_name('import/output_data:0')
 with tf.Session() as sess:
  for image_file in image_files:
   abs_path = os.path.join(image_folder, image_file)
   img = cv2.imread(abs_path).astype(np.float32)
   img = cv2.resize(img, (int(input_node.shape[1]), int(input_node.shape[2])))
   output_data = sess.run(output_node, feed_dict={input_node: [img]})
   index = np.argmax(output_data)
   label = dict_laebl[index]
   dst_floder = os.path.join(result_folder, label)
   if not os.path.exists(dst_floder):
    os.mkdir(dst_floder)
   cv2.imwrite(os.path.join(dst_floder, image_file), img)
   count += 1

tflite测试代码

model_path = "converted_model.tflite" #"/datastore1/Colonist_Lord/Colonist_Lord/data/passport_char/ocr.tflite"
interpreter = tf.contrib.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
 
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for image_file in image_files:
 abs_path = os.path.join(image_folder,image_file)
 img = cv2.imread(abs_path).astype(np.float32)
 img = cv2.resize(img, tuple(input_details[0]['shape'][1:3]))
 # input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
 interpreter.set_tensor(input_details[0]['index'], [img])
 
 interpreter.invoke()
 output_data = interpreter.get_tensor(output_details[0]['index'])
 index = np.argmax(output_data)
 label = dict_laebl[index]
 dst_floder = os.path.join(result_folder,label)
 if not os.path.exists(dst_floder):
  os.mkdir(dst_floder)
 cv2.imwrite(os.path.join(dst_floder,image_file),img)
 count+=1

最后也算是绕过这个问题解决了业务需求,后面有空的话,还是会花时间研究一下这个问题。

如果有哪个大佬知道原因,希望不吝赐教。

补充知识:.pb 转tflite代码,使用量化,减小体积,converter.post_training_quantize = True

import tensorflow as tf

path = "/home/python/Downloads/a.pb" # pb文件位置和文件名
inputs = ["input_images"] # 模型文件的输入节点名称
classes = ['feature_fusion/Conv_7/Sigmoid','feature_fusion/concat_3'] # 模型文件的输出节点名称
# converter = tf.contrib.lite.TocoConverter.from_frozen_graph(path, inputs, classes, input_shapes={'input_images':[1, 320, 320, 3]})
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, classes,
              input_shapes={'input_images': [1, 320, 320, 3]})
converter.post_training_quantize = True
tflite_model = converter.convert()
open("/home/python/Downloads/aNew.tflite", "wb").write(tflite_model)

以上这篇tensorflow pb to tflite 精度下降详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用PDB库调试程序
Apr 05 Python
python Socket之客户端和服务端握手详解
Sep 18 Python
python实现报表自动化详解
Nov 16 Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 Python
python+opencv实现摄像头调用的方法
Jun 22 Python
简单了解django缓存方式及配置
Jul 19 Python
基于Python的微信机器人开发 微信登录和获取好友列表实现解析
Aug 21 Python
Django之模板层的实现代码
Sep 09 Python
Django 请求Request的具体使用方法
Nov 11 Python
tensorflow实现残差网络方式(mnist数据集)
May 26 Python
idea2020手动安装python插件的实现方法
Jul 17 Python
Python连接Impala实现步骤解析
Aug 04 Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 #Python
5行Python代码实现图像分割的步骤详解
May 25 #Python
Win10用vscode打开anaconda环境中的python出错问题的解决
May 25 #Python
You might like
PHP也可以?成Shell Script
2006/10/09 PHP
cakephp打印sql语句的方法
2015/02/13 PHP
php array_walk 对数组中的每个元素应用用户自定义函数详解
2016/11/18 PHP
Laravel5.* 打印出执行的sql语句的方法
2017/07/24 PHP
php 删除一维数组中某一个值元素的操作方法
2018/02/01 PHP
PHP实现的杨辉三角求解算法分析
2019/03/11 PHP
php实现登录页面的简单实例
2019/09/29 PHP
学习YUI.Ext第五日--做拖放Darg&Drop
2007/03/10 Javascript
完整显示当前日期和时间的JS代码
2007/09/17 Javascript
ext监听事件方法[初级篇]
2008/04/27 Javascript
ajax无刷新动态调用股票信息(改良版)
2008/11/01 Javascript
js更优雅的兼容
2010/08/12 Javascript
基于jquery的lazy loader插件实现图片的延迟加载[简单使用]
2011/05/07 Javascript
玩转jQuery按钮 请告诉我你最喜欢哪些?
2012/01/08 Javascript
javascript判断并获取注册表中可信任站点的方法
2015/06/01 Javascript
详解jQuery向动态生成的内容添加事件响应jQuery live()方法
2015/11/02 Javascript
轻松实现Bootstrap图片轮播
2020/04/20 Javascript
基于JavaScript实现瀑布流布局(二)
2016/01/26 Javascript
基于JS实现弹出一个隐藏的div窗口body页面变成灰色并且不可被编辑
2016/12/14 Javascript
Jqprint实现页面打印
2017/01/06 Javascript
小程序云开发实战小结
2018/10/25 Javascript
js中位数不足自动补位扩展padLeft、padRight实现代码
2020/04/06 Javascript
Python函数式编程指南(一):函数式编程概述
2015/06/24 Python
利用Python将时间或时间间隔转为ISO 8601格式方法示例
2017/09/05 Python
Python中列表与元组的乘法操作示例
2018/02/10 Python
windows下python 3.6.4安装配置图文教程
2018/08/21 Python
python创建文件时去掉非法字符的方法
2018/10/31 Python
Python数据类型之Number数字操作实例详解
2019/05/08 Python
使用OpenCV实现仿射变换—缩放功能
2019/08/29 Python
python psutil监控进程实例
2019/12/17 Python
python字典通过值反查键的实现(简洁写法)
2020/09/30 Python
python实现简单文件读写函数
2021/02/25 Python
十八届三中全会个人学习材料
2014/02/13 职场文书
毕业自我鉴定总结
2014/03/24 职场文书
银行委托书范本
2014/04/04 职场文书
MySQL分库分表详情
2021/09/25 MySQL