将keras的h5模型转换为tensorflow的pb模型操作


Posted in Python onMay 25, 2020

背景:目前keras框架使用简单,很容易上手,深得广大算法工程师的喜爱,但是当部署到客户端时,可能会出现各种各样的bug,甚至不支持使用keras,本文来解决的是将keras的h5模型转换为客户端常用的tensorflow的pb模型并使用tensorflow加载pb模型。

h5_to_pb.py
 
from keras.models import load_model
import tensorflow as tf
import os 
import os.path as osp
from keras import backend as K
#路径参数
input_path = 'input path'
weight_file = 'weight.h5'
weight_file_path = osp.join(input_path,weight_file)
output_graph_name = weight_file[:-3] + '.pb'
#转换函数
def h5_to_pb(h5_model,output_dir,model_name,out_prefix = "output_",log_tensorboard = True):
  if osp.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i],out_prefix + str(i + 1))
  sess = K.get_session()
  from tensorflow.python.framework import graph_util,graph_io
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
  graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir,model_name),output_dir)
#输出路径
output_dir = osp.join(os.getcwd(),"trans_model")
#加载模型
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model,output_dir = output_dir,model_name = output_graph_name)
print('model saved')

将转换成的pb模型进行加载

load_pb.py
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
def load_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
 
  print(sess.run('b:0'))
  #输入
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  #输出
  op = sess.graph.get_tensor_by_name('op_to_store:0')
  #预测结果
  ret = sess.run(op, {input_x: 3, input_y: 4})
  print(ret)

补充知识:h5模型转化为pb模型,代码及排坑

我是在实际工程中要用到tensorflow训练的pb模型,但是训练的代码是用keras写的,所以生成keras特定的h5模型,所以用到了h5_to_pb.py函数。

附上h5_to_pb.py(python3)

#*-coding:utf-8-*

"""
将keras的.h5的模型文件,转换成TensorFlow的pb文件
"""
# ==========================================================

from keras.models import load_model
import tensorflow as tf
import os.path as osp
import os
from keras import backend
#from keras.models import Sequential

def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
  """.h5模型文件转换成pb模型文件
  Argument:
    h5_model: str
      .h5模型文件
    output_dir: str
      pb模型文件保存路径
    model_name: str
      pb模型文件名称
    out_prefix: str
      根据训练,需要修改
    log_tensorboard: bool
      是否生成日志文件
  Return:
    pb模型文件
  """
  if os.path.exists(output_dir) == False:
    os.mkdir(output_dir)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  sess = backend.get_session()

  from tensorflow.python.framework import graph_util, graph_io
  # 写入pb模型文件
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  # 输出日志文件
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

if __name__ == '__main__':
  # .h模型文件路径参数
  input_path = 'D:/CSP'
  weight_file = 'xingren.h5'
  weight_file_path = os.path.join(input_path, weight_file)
  output_graph_name = weight_file[:-3] + '.pb'

  # pb模型文件输出输出路径
  output_dir = osp.join(os.getcwd(),"trans_model")
  #model.save(xingren.h5)
  # 加载模型
  #h5_model = Sequential()
  h5_model = load_model(weight_file_path)
  #h5_model.save(weight_file_path)
  #h5_model.save('xingren.h5')
  h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
  print ('Finished')

在运行的时候遇到了下面问题:

将keras的h5模型转换为tensorflow的pb模型操作

原因:我们训练模型的时候用save_weights函数保存模型,但是这个函数只保存了权重文件,并没有又保存模型的参数。要把save_weights改为save。

下边是两个函数介绍:

save()保存的模型结果,它既保持了模型的图结构,又保存了模型的参数。

save_weights()保存的模型结果,它只保存了模型的参数,但并没有保存模型的图结构

以上这篇将keras的h5模型转换为tensorflow的pb模型操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python判断IP地址合法性的方法实例
Mar 13 Python
python实现模拟按键,自动翻页看u17漫画
Mar 17 Python
python3使用PyMysql连接mysql数据库实例
Feb 07 Python
python+selenium实现京东自动登录及秒杀功能
Nov 18 Python
Django 使用Ajax进行前后台交互的示例讲解
May 28 Python
Python实现读取机器硬件信息的方法示例
Jun 09 Python
Pycharm 文件更改目录后,执行路径未更新的解决方法
Jul 19 Python
python 3.7.4 安装 opencv的教程
Oct 10 Python
python通过matplotlib生成复合饼图
Feb 06 Python
Python各种扩展名区别点整理
Feb 27 Python
Django-xadmin后台导入json数据及后台显示信息图标和主题更改方式
Mar 11 Python
Python如何将将模块分割成多个文件
Aug 04 Python
tensorflow转换ckpt为savermodel模型的实现
May 25 #Python
基于Python把网站域名解析成ip地址
May 25 #Python
使用keras和tensorflow保存为可部署的pb格式
May 25 #Python
Python使用configparser读取ini配置文件
May 25 #Python
浅谈tensorflow模型保存为pb的各种姿势
May 25 #Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 #Python
keras模型保存为tensorflow的二进制模型方式
May 25 #Python
You might like
PHILIPS L4X25T电路分析和打理
2021/03/02 无线电
php.ini中文版
2006/10/09 PHP
PHP个人网站架设连环讲(四)
2006/10/09 PHP
用PHP和ACCESS写聊天室(四)
2006/10/09 PHP
PHP循环结构实例讲解
2014/02/10 PHP
解析Jquery取得iframe中元素的几种方法
2013/07/04 Javascript
css与javascript跨浏览器兼容性总结
2014/09/15 Javascript
Nodejs学习笔记之Stream模块
2015/01/13 NodeJs
JavaScript实现的简单幂函数实例
2015/04/17 Javascript
JavaScript截取、切割字符串的技巧
2016/01/07 Javascript
利用javascript实现的三种图片放大镜效果实例(附源码)
2017/01/23 Javascript
最常用的jQuery表单验证(简单)
2017/05/23 jQuery
Vue实现路由跳转和嵌套
2017/06/20 Javascript
NodeJs通过async/await处理异步的方法
2017/10/09 NodeJs
结合Vue控制字符和字节的显示个数的示例
2018/05/17 Javascript
微信小程序与后台PHP交互的方法实例分析
2018/12/10 Javascript
详解超简单的react服务器渲染(ssr)入坑指南
2019/02/28 Javascript
ES6中的迭代器、Generator函数及Generator函数的异步操作方法
2019/05/12 Javascript
[01:13]2015国际邀请赛线下观战现场
2015/08/08 DOTA
[48:24]完美世界DOTA2联赛PWL S3 Forest vs INK ICE 第一场 12.09
2020/12/12 DOTA
[41:52]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第二场 2月22日
2021/03/11 DOTA
windows下wxPython开发环境安装与配置方法
2014/06/28 Python
Python中.join()和os.path.join()两个函数的用法详解
2018/06/11 Python
Python多进程原理与用法分析
2018/08/21 Python
python脚本调用iftop 统计业务应用流量的思路详解
2019/10/11 Python
关于ZeroMQ 三种模式python3实现方式
2019/12/23 Python
pytorch 中的重要模块化接口nn.Module的使用
2020/04/02 Python
python Matplotlib模块的使用
2020/09/16 Python
美国运动鞋和服装网上商店:YCMC
2018/09/15 全球购物
Etam艾格英国官网:法国著名女装品牌
2019/04/15 全球购物
盛大笔试题
2016/11/05 面试题
《大江保卫战》教学反思
2014/04/11 职场文书
党课培训心得体会
2014/09/02 职场文书
纪念九一八事变演讲稿1000字
2014/09/14 职场文书
2016年社区“6.26”禁毒日宣传活动总结
2016/04/05 职场文书
解决jupyter notebook启动后没有token的坑
2021/04/24 Python