keras模型保存为tensorflow的二进制模型方式


Posted in Python onMay 25, 2020

最近需要将使用keras训练的模型移植到手机上使用, 因此需要转换到tensorflow的二进制模型。

折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码:

# coding=utf-8
import sys

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 """
 Freezes the state of a session into a prunned computation graph.

 Creates a new computation graph where variable nodes are replaced by
 constants taking their current value in the session. The new graph will be
 prunned so subgraphs that are not neccesary to compute the requested
 outputs are removed.
 @param session The TensorFlow session to be frozen.
 @param keep_var_names A list of variable names that should not be frozen,
       or None to freeze all the variables in the graph.
 @param output_names Names of the relevant graph outputs.
 @param clear_devices Remove the device directives from the graph for better portability.
 @return The frozen graph definition.
 """
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
              output_names, freeze_var_names)
  return frozen_graph

input_fld = sys.path[0]
weight_file = 'your_model.h5'
output_graph_name = 'tensor_model.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
 os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

上面代码实现保存到当前目录的tensor_model目录下。

验证:

import tensorflow as tf
import numpy as np
import PIL.Image as Image
import cv2

def recognize(jpg_path, pb_file_path):
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()

  with open(pb_file_path, "rb") as f:
   output_graph_def.ParseFromString(f.read())
   tensors = tf.import_graph_def(output_graph_def, name="")
   print tensors

  with tf.Session() as sess:
   init = tf.global_variables_initializer()
   sess.run(init)

   op = sess.graph.get_operations()
   
   for m in op:
    print(m.values())

   input_x = sess.graph.get_tensor_by_name("convolution2d_1_input:0") #具体名称看上一段代码的input.name
   print input_x

   out_softmax = sess.graph.get_tensor_by_name("activation_4/Softmax:0") #具体名称看上一段代码的output.name

   print out_softmax

   img = cv2.imread(jpg_path, 0)
   img_out_softmax = sess.run(out_softmax,
          feed_dict={input_x: 1.0 - np.array(img).reshape((-1,28, 28, 1)) / 255.0})

   print "img_out_softmax:", img_out_softmax
   prediction_labels = np.argmax(img_out_softmax, axis=1)
   print "label:", prediction_labels

pb_path = 'tensorflow_model/constant_graph_weights.pb'
img = 'test/6/8_48.jpg'
recognize(img, pb_path)

补充知识:如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用

首先keras训练好的模型通过自带的model.save()保存下来是 .model (.h5) 格式的文件

模型载入是通过 my_model = keras . models . load_model( filepath )

要将该模型转换为.pb 格式的TensorFlow 模型,代码如下:

# -*- coding: utf-8 -*-
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers import Dropout
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential,load_model
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
import collections
from collections import defaultdict
import jieba
import numpy as np
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
              output_names, freeze_var_names)
  return frozen_graph
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
 os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=True)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

然后模型就存成了.pb格式的文件

问题就来了,这样存下来的.pb格式的文件是frozen model

如果通过TensorFlow serving 启用模型的话,会报错:

E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`

因为TensorFlow serving 希望读取的是saved model

于是需要将frozen model 转化为 saved model 格式,解决方案如下:

from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
 # name="" is important to ensure we don't get spurious prefixing
 tf.import_graph_def(graph_def, name="")
 g = tf.get_default_graph()
 inp = g.get_tensor_by_name(net_model.input.name)
 out = g.get_tensor_by_name(net_model.output.name)

 sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
  tf.saved_model.signature_def_utils.predict_signature_def(
   {"in": inp}, {"out": out})

 builder.add_meta_graph_and_variables(sess,
           [tag_constants.SERVING],
           signature_def_map=sigs)
builder.save()

于是保存下来的saved model 文件夹下就有两个文件:

saved_model.pb variables

其中variables 可以为空

于是将.pb 模型导入serving再读取,成功!

以上这篇keras模型保存为tensorflow的二进制模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现监控windows服务并自动启动服务示例
Apr 17 Python
零基础写python爬虫之爬虫框架Scrapy安装配置
Nov 06 Python
Python中函数参数设置及使用的学习笔记
May 03 Python
Python 判断 有向图 是否有环的实例讲解
Feb 01 Python
Python多进程fork()函数详解
Feb 22 Python
Python 3 实现定义跨模块的全局变量和使用教程
Jul 07 Python
python执行scp命令拷贝文件及文件夹到远程主机的目录方法
Jul 08 Python
在SQLite-Python中实现返回、查询中文字段的方法
Jul 17 Python
Python 解决相对路径问题:"No such file or directory"
Jun 05 Python
Python模块常用四种安装方式
Oct 20 Python
Python中免验证跳转到内容页的实例代码
Oct 23 Python
Python中的socket网络模块介绍
Jul 23 Python
keras 如何保存最佳的训练模型
May 25 #Python
keras处理欠拟合和过拟合的实例讲解
May 25 #Python
python如何调用字典的key
May 25 #Python
如何使用python的ctypes调用医保中心的dll动态库下载医保中心的账单
May 24 #Python
Python+PyQt5实现灭霸响指功能
May 25 #Python
PyQt5实现仿QQ贴边隐藏功能的实例代码
May 24 #Python
通过Python扫描代码关键字并进行预警的实现方法
May 24 #Python
You might like
PHP中的加密功能
2006/10/09 PHP
PHP面向对象之里氏替换原则简单示例
2018/04/08 PHP
thinkPHP5框架自定义验证器实现方法分析
2018/06/11 PHP
基于PHP实现用户登录注册功能的详细教程
2020/08/04 PHP
硬盘浏览程序,保存成网页格式便可使用
2006/12/03 Javascript
javascript 限制输入和粘贴(IE,firefox测试通过)
2008/11/14 Javascript
基于jQuery的为attr添加id title等效果的实现代码
2011/04/20 Javascript
javascript实现颜色渐变的方法
2013/10/30 Javascript
JQuery获取与设置HTML元素的内容或文本的实现代码
2014/06/20 Javascript
js仿土豆网带缩略图的焦点图片切换效果实现方法
2015/02/23 Javascript
CSS图片响应式 垂直水平居中
2015/08/14 Javascript
JavaScript中的prototype原型学习指南
2016/05/09 Javascript
jquery表单插件Autotab使用方法详解
2016/06/24 Javascript
JS正则获取HTML元素的方法
2017/03/31 Javascript
使用重写url机制实现验证码换一张功能
2017/08/01 Javascript
详解nodeJs文件系统(fs)与流(stream)
2018/01/24 NodeJs
Vue组件通信的四种方式汇总
2018/02/08 Javascript
webstorm+vue初始化项目的方法
2018/10/18 Javascript
小程序scroll-view安卓机隐藏横向滚动条的实现详解
2019/05/16 Javascript
微信小程序保持session会话的方法
2020/03/20 Javascript
[06:15]2016国际邀请赛中国区预选赛单车采访:我顶WINGS
2016/06/27 DOTA
[01:07:41]IG vs VGJ.T 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
[01:11:21]DOTA2-DPC中国联赛 正赛 VG vs Elephant BO3 第一场 3月6日
2021/03/11 DOTA
Python Web框架Flask信号机制(signals)介绍
2015/01/01 Python
Python中set与frozenset方法和区别详解
2016/05/23 Python
pygame实现贪吃蛇游戏(上)
2019/10/29 Python
很酷的HTML5电子书翻页动画特效
2016/02/25 HTML / CSS
历史学专业大学生找工作的自我评价
2013/10/16 职场文书
志愿者服务感言
2014/02/27 职场文书
内勤岗位职责
2015/02/10 职场文书
2015年酒店年度工作总结
2015/05/23 职场文书
2015暑假实习报告范文
2015/07/13 职场文书
感恩教师主题班会
2015/08/12 职场文书
深度学习详解之初试机器学习
2021/04/14 Python
交互式可视化js库gojs使用介绍及技巧
2022/02/18 Javascript
Redis 异步机制
2022/05/15 Redis