Keras模型转成tensorflow的.pb操作


Posted in Python onJuly 06, 2020

Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
 
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
 
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
 
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
 from tensorflow.python.framework.graph_util import convert_variables_to_constants
 graph = session.graph
 with graph.as_default():
  freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
  output_names = output_names or []
  output_names += [v.op.name for v in tf.global_variables()]
  input_graph_def = graph.as_graph_def()
  if clear_devices:
   for node in input_graph_def.node:
    node.device = ""
  frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
  return frozen_graph
 
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
 
print('input is :', model.input.name)
print ('output is:', model.output.name)
 
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
 
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))

补充知识:keras h5 model 转换为tflite

在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。

环境

tensorflow 1.12.0

python 3.6.5

h5 model saved by `model.save('tf.h5')`

直接转换

`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`

先转成pb再转tflite

```

git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \

 --output_file=tf.tflite \
 --graph_def_file=tf.pb \
 --input_arrays=convolution2d_1_input \
 --output_arrays=dense_3/BiasAdd \
 --input_shape=1,3,448,448
```

参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应

官网是说需要用到tenorboard来查看,一个比较trick的方法

先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。

以上这篇Keras模型转成tensorflow的.pb操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用SimpleParse模块进行解析的教程
Apr 11 Python
python实现根据主机名字获得所有ip地址的方法
Jun 28 Python
python 删除大文件中的某一行(最有效率的方法)
Aug 19 Python
解决Spyder中图片显示太小的问题
Apr 27 Python
python leetcode 字符串相乘实例详解
Sep 03 Python
python3 读取Excel表格中的数据
Oct 16 Python
Python文件操作中进行字符串替换的方法(保存到新文件/当前文件)
Jun 28 Python
pycharm新建一个python工程步骤
Jul 16 Python
Python3 socket即时通讯脚本实现代码实例(threading多线程)
Jun 01 Python
记一次Django响应超慢的解决过程
Sep 17 Python
python程序的组织结构详解
Dec 06 Python
Django框架中表单的用法
Jun 10 Python
python如何进入交互模式
Jul 06 #Python
python3.4中清屏的处理方法
Jul 06 #Python
Python3基于print打印带颜色字符串
Jul 06 #Python
python判断是空的实例分享
Jul 06 #Python
python三引号如何输入
Jul 06 #Python
如何验证python安装成功
Jul 06 #Python
使用Keras训练好的.h5模型来测试一个实例
Jul 06 #Python
You might like
PHP编程函数安全篇
2013/01/08 PHP
PHP正则替换函数preg_replace和preg_replace_callback使用总结
2014/09/22 PHP
php数组生成html下拉列表的方法
2015/07/20 PHP
PHP设计模式之简单工厂和工厂模式实例分析
2019/03/25 PHP
最短的IE判断代码
2011/03/13 Javascript
转义字符(\)对JavaScript中JSON.parse的影响概述
2013/07/17 Javascript
JS实现模仿微博发布效果实例代码
2013/12/16 Javascript
javascript实现存储hmtl字符串示例
2014/04/25 Javascript
DOM节点的替换或修改函数replaceChild()用法实例
2015/01/12 Javascript
jQuery中unwrap()方法用法实例
2015/01/16 Javascript
jQuery中设置form表单中action值的实现方法
2016/05/25 Javascript
JS实现iframe自适应高度的方法(兼容IE与FireFox)
2016/06/24 Javascript
概述javascript在Google IE中的调试技巧
2016/11/24 Javascript
JS实现页面打印(整体、局部)
2017/08/18 Javascript
详解vue-router 路由元信息
2017/09/13 Javascript
详解创建自定义的Angular Schematics
2018/06/06 Javascript
基于JavaScript实现每日签到打卡轨迹功能
2018/11/29 Javascript
vue-cli3添加模式配置多环境变量的方法
2019/06/05 Javascript
layer.msg()去掉默认时间,实现手动关闭的方法
2019/09/12 Javascript
JavaScript实现省市联动效果
2019/11/22 Javascript
使用vue重构资讯页面的实例代码解析
2019/11/26 Javascript
利用PHP实现递归删除链表元素的方法示例
2020/10/23 Javascript
[01:13:08]2018DOTA2亚洲邀请赛4.6 淘汰赛 mineski vs LGD 第二场
2018/04/10 DOTA
python 视频逐帧保存为图片的完整实例
2019/12/10 Python
手把手教你安装Windows版本的Tensorflow
2020/03/26 Python
纯CSS3实现滚动的齿轮动画效果
2014/06/05 HTML / CSS
草莓网化妆品加拿大网站:Strawberrynet Canada
2016/09/20 全球购物
美术教师自我鉴定
2014/02/12 职场文书
上课看小说检讨书
2014/02/22 职场文书
药品营销策划方案
2014/06/15 职场文书
垃圾分类的活动方案
2014/08/15 职场文书
购房公证委托书(2014版)
2014/09/12 职场文书
个人承诺书格式范文
2015/04/29 职场文书
驳回起诉民事裁定书
2015/05/19 职场文书
导游词之珠海轮廓
2019/10/25 职场文书
详解Python小数据池和代码块缓存机制
2021/04/07 Python