tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)


Posted in Python onApril 22, 2020

网上关于tensorflow模型文件ckpt格式转pb文件的帖子很多,本人几乎尝试了所有方法,最后终于成功了,现总结如下。方法无外乎下面两种:

  • 使用tensorflow.python.tools.freeze_graph.freeze_graph
  • 使用graph_util.convert_variables_to_constants

1、tensorflow模型的文件解读

使用tensorflow训练好的模型会自动保存为四个文件,如下

tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)

checkpoint:记录近几次训练好的模型结果(名称)。

xxx.data-00000-of-00001: 模型的所有变量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型训练好参数和其他值。

xxx.index :模型的元数据,二进制或者其他格式,不可直接查看 。是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和一些辅助数据等。

xxx.meta:模型的meta数据 ,二进制或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息,通俗地讲就是神经网络的网络结构。

2、最常见的ckpt转pb文件的方法

2、ckpt转pb文件(freeze_graph.freeze_graph)

此种方法尝试成功,虽然不知道输出节点名,但是只要模型代码还在就可以操作,直接上代码。

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from model import network # network是你们自己定义的模型结构(代码结构)
# egs:
# def network(input):
# return tf.layers.softmax(input)
 
model_path = "model.ckpt-0000" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
 
def main():
 tf.reset_default_graph()
 # 设置输入网络的数据维度,根据训练时的模型输入数据的维度自行修改
 input_node = tf.placeholder(tf.float32, shape=(None, None, 200)) 
 output_node = network(input_node) # 神经网络的输出
 # 设置输出数据类型(特别注意,这里必须要跟输出网络参数的数据格式保持一致,不然会导致模型预测  精度或者预测能力的丢失)以及重新定义输出节点的名字(这样在后面保存pb文件以及之后使用pb文件时直接使用重新定义的节点名字即可)
 flow = tf.cast(output_node , tf.float16, 'the_outputs') 
 saver = tf.train.Saver()
 with tf.Session() as sess:
 saver.restore(sess, model_path)
 #保存模型图(结构),为一个json文件
 tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
 #将模型参数与模型图结合,并保存为pb文件
 freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'the_outputs','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")
 print("done")
if __name__ == '__main__':
 main()

2、ckpt转pb文件(graph_util.convert_variables_to_constants)

没有成功,因为不知道输出节点的名字,使用该方法保存后的pb文件只有几十k,无法使用,写在这里主要是为了总结。直接上代码,代码里面没有的库(函数),按提示自行import。

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess=sess,
  input_graph_def=input_graph_def,# 等于:sess.graph_def
  output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
  f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)

参考链接:

到此这篇关于tensorflow模型文件(ckpt)转pb文件(不知道输出节点名)的文章就介绍到这了,更多相关tensorflow ckpt转pb文件内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python标准库os.path包、glob包使用实例
Nov 25 Python
详解字典树Trie结构及其Python代码实现
Jun 03 Python
浅谈Python生成器generator之next和send的运行流程(详解)
May 08 Python
使用Python的package机制如何简化utils包设计详解
Dec 11 Python
Python通过matplotlib画双层饼图及环形图简单示例
Dec 15 Python
Python 创建空的list,以及append用法讲解
May 04 Python
matplotlib subplots 设置总图的标题方法
May 25 Python
python一行sql太长折成多行并且有多个参数的方法
Jul 19 Python
Python爬虫框架Scrapy基本用法入门教程
Jul 26 Python
如何在Windows中安装多个python解释器
Jun 16 Python
matplotlib设置颜色、标记、线条,让你的图像更加丰富(推荐)
Sep 25 Python
Python+unittest+requests+excel实现接口自动化测试框架
Dec 23 Python
有趣的Python图片制作之如何用QQ好友头像拼接出里昂
Apr 22 #Python
python模拟斗地主发牌
Apr 22 #Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
Apr 22 #Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 #Python
matlab中二维插值函数interp2的使用详解
Apr 22 #Python
python 一维二维插值实例
Apr 22 #Python
Numpy一维线性插值函数的用法
Apr 22 #Python
You might like
ThinkPHP中redirect用法分析
2014/12/05 PHP
PHP XML和数组互相转换详解
2016/10/26 PHP
php把字符串指定字符分割成数组的方法
2018/03/12 PHP
JS中剪贴板兼容性、判断复制成功或失败
2021/03/09 Javascript
完美解决JS中汉字显示乱码问题(已解决)
2006/12/27 Javascript
location.search在客户端获取Url参数的方法
2010/06/08 Javascript
jQuery插件HighCharts绘制2D柱状图、折线图的组合双轴图效果示例【附demo源码下载】
2017/03/09 Javascript
基于JavaScript实现焦点图轮播效果
2017/03/27 Javascript
前端构建工具之gulp的语法教程
2017/06/12 Javascript
js中json对象和字符串的理解及相互转化操作实现方法
2017/09/22 Javascript
JS高阶函数原理与用法实例分析
2019/01/15 Javascript
vue component 中引入less文件报错 Module build failed
2019/04/17 Javascript
JavaScript变量作用域及内存问题实例分析
2019/06/10 Javascript
vue点击当前路由高亮小案例
2019/09/26 Javascript
浅谈Vue 函数式组件的使用技巧
2020/06/16 Javascript
[00:58]2016年国际邀请赛勇士令状宣传片
2016/06/01 DOTA
pycharm 使用心得(四)显示行号
2014/06/05 Python
python字典基本操作实例分析
2015/07/11 Python
Python中遍历字典过程中更改元素导致异常的解决方法
2016/05/12 Python
Python实现信用卡系统(支持购物、转账、存取钱)
2016/06/24 Python
python安装oracle扩展及数据库连接方法
2017/02/21 Python
Python3中使用PyMongo的方法详解
2017/07/28 Python
Python中super函数的用法
2017/11/17 Python
Django实现组合搜索的方法示例
2018/01/23 Python
python XlsxWriter模块创建aexcel表格的实例讲解
2018/05/03 Python
pycharm恢复默认设置或者是替换pycharm的解释器实例
2018/10/29 Python
python登录WeChat 实现自动回复实例详解
2019/05/28 Python
python线程信号量semaphore使用解析
2019/11/30 Python
Evisu官方网站:日本牛仔品牌,时尚街头设计风格
2016/12/30 全球购物
意大利灯具购物网站:Lampade.it
2018/10/18 全球购物
学历公证书范本
2014/04/09 职场文书
离婚答辩状范文
2015/05/22 职场文书
怒海潜将观后感
2015/06/11 职场文书
网吧温馨提示
2015/07/17 职场文书
为什么 Nginx 比 Apache 更牛逼
2021/03/31 Servers
SQL实现LeetCode(177.第N高薪水)
2021/08/04 MySQL