将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
windows系统下Python环境的搭建(Aptana Studio)
Mar 06 Python
详解Python多线程Selenium跨浏览器测试
Apr 01 Python
老生常谈Python之装饰器、迭代器和生成器
Jul 26 Python
Python编程给numpy矩阵添加一列方法示例
Dec 04 Python
Python统计单词出现的次数
Apr 04 Python
Python+OpenCV图片局部区域像素值处理详解
Jan 23 Python
对Python强大的可变参数传递机制详解
Jun 13 Python
django表单的Widgets使用详解
Jul 22 Python
python运用sklearn实现KNN分类算法
Oct 16 Python
Python random库使用方法及异常处理方案
Mar 02 Python
如何搭建pytorch环境的方法步骤
May 06 Python
在Keras中实现保存和加载权重及模型结构
Jun 15 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
Zend Studio (eclipse)使用速度优化方法
2011/03/23 PHP
php使用scandir()函数扫描指定目录下所有文件示例
2019/06/08 PHP
PHP实现一个按钮点击上传多个图片操作示例
2020/01/23 PHP
Javascript简单实现可拖动的div
2013/10/22 Javascript
JavaScript中的prototype和constructor简明总结
2014/04/05 Javascript
jQuery插件开发详细教程
2014/06/06 Javascript
javascript中JSON对象与JSON字符串相互转换实例
2015/07/11 Javascript
基于jquery实现省市联动特效
2015/12/17 Javascript
如何通过js实现图片预览功能【附实例代码】
2016/03/30 Javascript
JS代码随机生成姓名、手机号、身份证号、银行卡号
2016/04/27 Javascript
浅谈Angular.js中使用$watch监听模型变化
2017/01/10 Javascript
Avalonjs 实现简单购物车功能(实例代码)
2017/02/07 Javascript
JS中正则表达式全局匹配模式 /g用法详解
2017/04/01 Javascript
jquery引入外部CDN 加载失败则引入本地jq库
2018/05/23 jQuery
解决一个微信号同时支持多个环境网页授权问题
2019/08/07 Javascript
JavaScript实现滑动门效果
2020/01/18 Javascript
微信小程序搜索框样式并实现跳转到搜索页面(小程序搜索功能)
2020/03/10 Javascript
JavaScript进阶(一)变量声明提升实例分析
2020/05/09 Javascript
python实现ID3决策树算法
2017/12/20 Python
python实现监控阿里云账户余额功能
2019/12/16 Python
在pycharm中使用pipenv创建虚拟环境和安装django的详细教程
2020/11/30 Python
利用Python实现最小二乘法与梯度下降算法
2021/02/21 Python
CSS3实现莲花绽放的动画效果
2020/11/06 HTML / CSS
html5实现多文件的上传示例代码
2014/02/13 HTML / CSS
REISS美国官网:伦敦最受欢迎的时尚品牌
2019/08/16 全球购物
Prototype是怎么扩展DOM的
2014/10/01 面试题
Java中的类包括什么内容?设计时要注意哪些方面
2012/05/23 面试题
国贸专业大学生职业生涯规划范文
2014/01/10 职场文书
报到证丢失证明
2014/01/11 职场文书
车间安全生产标语
2014/06/06 职场文书
考研英语复习计划
2015/01/19 职场文书
电子商务专业求职信范文
2015/03/19 职场文书
2015年“公民道德宣传日”活动方案
2015/05/06 职场文书
教师素质教育心得体会
2016/01/19 职场文书
2016年学校爱国卫生月活动总结
2016/04/06 职场文书
Flutter Navigator 实现路由传递参数
2022/04/22 Java/Android