将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python代码解决RenderView窗口not found问题
Aug 28 Python
用tensorflow实现弹性网络回归算法
Jan 09 Python
django 使用 request 获取浏览器发送的参数示例代码
Jun 11 Python
PyCharm在新窗口打开项目的方法
Jan 17 Python
全面了解django的缓存机制及使用方法
Jul 22 Python
详解python列表(list)的使用技巧及高级操作
Aug 15 Python
Python帮你微信头像任意添加装饰别再@微信官方了
Sep 25 Python
利用Python校准本地时间的方法教程
Oct 31 Python
基于python实现微信好友数据分析(简单)
Feb 16 Python
python能做哪些生活有趣的事情
Sep 09 Python
Matplotlib可视化之添加让统计图变得简单易懂的注释
Jun 11 Python
python生成可执行exe控制Microsip自动填写号码并拨打功能
Jun 21 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
php Rename 更改文件、文件夹名称
2011/05/24 PHP
php通过正则表达式记取数据来读取xml的方法
2015/03/09 PHP
PHP图形计数器程序显示网站用户浏览量
2016/07/20 PHP
PHP自定义递归函数实现数组转JSON功能【支持GBK编码】
2018/07/17 PHP
JavaScript高级程序设计 客户端存储学习笔记
2011/09/10 Javascript
js事件监听机制(事件捕获)总结
2014/08/08 Javascript
jquery实现点击页面计算点击次数
2015/01/23 Javascript
jQuery插件ajaxFileUpload实现异步上传文件效果
2015/04/14 Javascript
js实现图片轮播效果
2015/12/19 Javascript
javascript正则表达式定义(语法)总结
2016/01/08 Javascript
Angular外部使用js调用Angular控制器中的函数方法或变量用法示例
2016/08/05 Javascript
JS控制静态页面传递参数并获取参数应用
2016/08/10 Javascript
js中scrollTop()方法和scroll()方法用法示例
2016/10/03 Javascript
JS DOMReady事件的六种实现方法总结
2016/11/23 Javascript
如何编写一个完整的Angular4 FormText 组件
2017/11/18 Javascript
在vue中添加Echarts图表的基本使用教程
2017/11/22 Javascript
JS中使用react-tooltip插件实现鼠标悬浮显示框
2019/05/15 Javascript
js最全的数组的降维5种办法(小结)
2020/04/28 Javascript
下载糗事百科的内容_python版
2008/12/07 Python
在Python中定义和使用抽象类的方法
2016/06/30 Python
PyQt5每天必学之QSplitter实现窗口分隔
2018/04/19 Python
Django rest framework工具包简单用法示例
2018/07/20 Python
Flask Web开发入门之文件上传(八)
2018/08/17 Python
python3.x 生成3维随机数组实例
2019/11/28 Python
python实现门限回归方式
2020/02/29 Python
python使用selenium爬虫知乎的方法示例
2020/10/28 Python
python软件测试Jmeter性能测试JDBC Request(结合数据库)的使用详解
2021/01/26 Python
Book Depository澳大利亚:世界领先的专业在线书店之一
2018/12/27 全球购物
sort命令的作用和用法
2012/11/04 面试题
内部类的定义、种类以及优点
2013/10/16 面试题
大学自主招生自荐信
2013/12/16 职场文书
自我评价的写作规则
2014/01/06 职场文书
英语一分钟演讲稿
2014/04/29 职场文书
志愿者爱心公益活动策划方案
2014/09/15 职场文书
查摆问题整改措施
2014/10/24 职场文书
监理中标通知书
2015/04/16 职场文书