keras模型保存为tensorflow的二进制模型方式


Posted in Python onMay 25, 2020

最近需要将使用keras训练的模型移植到手机上使用, 因此需要转换到tensorflow的二进制模型。

折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码:

# coding=utf-8
import sys

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 """
 Freezes the state of a session into a prunned computation graph.

 Creates a new computation graph where variable nodes are replaced by
 constants taking their current value in the session. The new graph will be
 prunned so subgraphs that are not neccesary to compute the requested
 outputs are removed.
 @param session The TensorFlow session to be frozen.
 @param keep_var_names A list of variable names that should not be frozen,
       or None to freeze all the variables in the graph.
 @param output_names Names of the relevant graph outputs.
 @param clear_devices Remove the device directives from the graph for better portability.
 @return The frozen graph definition.
 """
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
              output_names, freeze_var_names)
  return frozen_graph

input_fld = sys.path[0]
weight_file = 'your_model.h5'
output_graph_name = 'tensor_model.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
 os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

上面代码实现保存到当前目录的tensor_model目录下。

验证:

import tensorflow as tf
import numpy as np
import PIL.Image as Image
import cv2

def recognize(jpg_path, pb_file_path):
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()

  with open(pb_file_path, "rb") as f:
   output_graph_def.ParseFromString(f.read())
   tensors = tf.import_graph_def(output_graph_def, name="")
   print tensors

  with tf.Session() as sess:
   init = tf.global_variables_initializer()
   sess.run(init)

   op = sess.graph.get_operations()
   
   for m in op:
    print(m.values())

   input_x = sess.graph.get_tensor_by_name("convolution2d_1_input:0") #具体名称看上一段代码的input.name
   print input_x

   out_softmax = sess.graph.get_tensor_by_name("activation_4/Softmax:0") #具体名称看上一段代码的output.name

   print out_softmax

   img = cv2.imread(jpg_path, 0)
   img_out_softmax = sess.run(out_softmax,
          feed_dict={input_x: 1.0 - np.array(img).reshape((-1,28, 28, 1)) / 255.0})

   print "img_out_softmax:", img_out_softmax
   prediction_labels = np.argmax(img_out_softmax, axis=1)
   print "label:", prediction_labels

pb_path = 'tensorflow_model/constant_graph_weights.pb'
img = 'test/6/8_48.jpg'
recognize(img, pb_path)

补充知识:如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用

首先keras训练好的模型通过自带的model.save()保存下来是 .model (.h5) 格式的文件

模型载入是通过 my_model = keras . models . load_model( filepath )

要将该模型转换为.pb 格式的TensorFlow 模型,代码如下:

# -*- coding: utf-8 -*-
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers import Dropout
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential,load_model
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
import collections
from collections import defaultdict
import jieba
import numpy as np
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
              output_names, freeze_var_names)
  return frozen_graph
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
 os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=True)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

然后模型就存成了.pb格式的文件

问题就来了,这样存下来的.pb格式的文件是frozen model

如果通过TensorFlow serving 启用模型的话,会报错:

E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`

因为TensorFlow serving 希望读取的是saved model

于是需要将frozen model 转化为 saved model 格式,解决方案如下:

from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
 # name="" is important to ensure we don't get spurious prefixing
 tf.import_graph_def(graph_def, name="")
 g = tf.get_default_graph()
 inp = g.get_tensor_by_name(net_model.input.name)
 out = g.get_tensor_by_name(net_model.output.name)

 sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
  tf.saved_model.signature_def_utils.predict_signature_def(
   {"in": inp}, {"out": out})

 builder.add_meta_graph_and_variables(sess,
           [tag_constants.SERVING],
           signature_def_map=sigs)
builder.save()

于是保存下来的saved model 文件夹下就有两个文件:

saved_model.pb variables

其中variables 可以为空

于是将.pb 模型导入serving再读取,成功!

以上这篇keras模型保存为tensorflow的二进制模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
php使用递归与迭代实现快速排序示例
Jan 23 Python
python多线程方式执行多个bat代码
Jun 07 Python
解读! Python在人工智能中的作用
Nov 14 Python
python-opencv 将连续图片写成视频格式的方法
Jan 08 Python
Python如何筛选序列中的元素的方法实现
Jul 15 Python
pywinauto自动化操作记事本
Aug 26 Python
Python udp网络程序实现发送、接收数据功能示例
Dec 09 Python
Pycharm 2020年最新激活码(亲测有效)
Sep 18 Python
python网络编程:socketserver的基本使用方法实例分析
Apr 09 Python
Python pysnmp使用方法及代码实例
Aug 24 Python
基于Python组装jmx并调用JMeter实现压力测试
Nov 03 Python
Restful_framework视图组件代码实例解析
Nov 17 Python
keras 如何保存最佳的训练模型
May 25 #Python
keras处理欠拟合和过拟合的实例讲解
May 25 #Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
You might like
PHP添加文字水印或图片水印的水印类完整源代码与使用示例
2019/03/18 PHP
PHP实现简单的协程任务调度demo示例
2020/02/01 PHP
javascript 45种缓动效果 非常酷
2011/06/28 Javascript
工作需要写的一个js拖拽组件
2011/07/28 Javascript
UI Events 用户界面事件
2012/06/27 Javascript
5秒后跳转到另一个页面的js代码
2013/10/12 Javascript
js动态修改整个页面样式达到换肤效果
2014/05/23 Javascript
老生常谈 js中this的指向
2016/06/30 Javascript
javaScript给元素添加多个class的简单实现
2016/07/20 Javascript
JavaScript中附件预览功能实现详解(推荐)
2017/08/15 Javascript
js实现div色块碰撞
2020/01/16 Javascript
JQuery复选框全选效果如何实现
2020/05/08 jQuery
[48:52]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第一局
2016/02/25 DOTA
[02:33]DOTA2亚洲邀请赛趣味视频之吐真话筒
2018/03/31 DOTA
[01:09:40]Newbee vs Pain 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
利用Celery实现Django博客PV统计功能详解
2017/05/08 Python
浅谈python连续赋值可能引发的错误
2018/11/10 Python
PyQt5实现五子棋游戏(人机对弈)
2020/03/24 Python
python 实现二维字典的键值合并等函数
2019/12/06 Python
Django 构建模板form表单的两种方法
2020/06/14 Python
python中return如何写
2020/06/18 Python
Python:__eq__和__str__函数的使用示例
2020/09/26 Python
HTML5超文本标记语言的实现方法
2020/09/24 HTML / CSS
LN-CC美国:伦敦时尚生活的缩影
2019/02/19 全球购物
护理专业自荐信范文
2014/02/26 职场文书
《孙权劝学》教学反思
2014/04/23 职场文书
《猴子种果树》教学反思
2014/04/26 职场文书
销售个人求职信范文
2014/04/28 职场文书
服装仓管员岗位职责
2014/06/17 职场文书
企业党建工作汇报材料
2014/08/19 职场文书
2014年惩防体系建设工作总结
2014/12/01 职场文书
食堂采购员岗位职责
2015/04/03 职场文书
小学校园广播稿
2015/08/18 职场文书
Python实战之实现简易的学生选课系统
2021/05/25 Python
Python干货实战之八音符酱小游戏全过程详解
2021/10/24 Python
深入讲解数据库中Decimal类型的使用以及实现方法
2022/02/15 MySQL