tensorflow实现将ckpt转pb文件的方法


Posted in Python onApril 22, 2020

   本博客实现将自己训练保存的ckpt模型转换为pb文件,该方法适用于任何ckpt模型,当然你需要确定ckpt模型输入/输出的节点名称。

   使用 tf.train.saver()保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方法是在TensorFlow中是最常用的保存方式。

    例如:下面的代码运行后,会在save目录下保存了四个文件:

import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
 print("v2:", sess.run(v2))
 saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
 print("Model saved in file:", saver_path)

    其中,checkpoint是检查点文件,文件保存了一个目录下所有的模型文件列表;
model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构,该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用。
ckpt.data : 保存模型中每个变量的取值
   但很多时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。 我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

    TensoFlow为我们提供了convert_variables_to_constants()方法,该方法可以固化模型结构,将计算图中的变量取值以常量的形式保存,而且保存的模型可以移植到Android平台。

一、CKPT 转换成 PB格式

    将CKPT 转换成 PB格式的文件的过程可简述如下:

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化
 下面的CKPT 转换成 PB格式例子,是我训练GoogleNet InceptionV3模型保存的ckpt转pb文件的例子,训练过程可参考博客:《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》:

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=input_graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())

说明:

1、函数freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于freeze操作,我们需要定义输出结点的名字。因为网络其实是比较复杂的,定义了输出结点的名字,那么freeze的时候就只把输出该结点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们freeze模型的目的是接下来做预测。所以,output_node_names一般是网络模型最后一层输出的节点名称,或者说就是我们预测的目标。

 2、在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。注意节点名称与张量的名称的区别,例如:“input:0”是张量的名称,而"input"表示的是节点的名称。

3、源码中通过graph = tf.get_default_graph()获得默认的图,这个图就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢复的图,因此必须先执行tf.train.import_meta_graph,再执行tf.get_default_graph() 。

4、实质上,我们可以直接在恢复的会话sess中,获得默认的网络图,更简单的方法,如下:

def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点

调用方法很简单,输入ckpt模型路径,输出pb模型的路径即可:

    # 输入ckpt模型路径
    input_checkpoint='models/model.ckpt-10000'
    # 输出pb模型的路径
    out_pb_path="models/pb/frozen_model.pb"
    # 调用freeze_graph将ckpt转为pb
    freeze_graph(input_checkpoint,out_pb_path)

5、上面以及说明:在保存的时候,通过convert_variables_to_constants函数来指定需要固化的节点名称,对于鄙人的代码,需要固化的节点只有一个:output_node_names。因此,其他网络模型,也可以通过简单的修改输出的节点名称output_node_names,将ckpt转为pb文件 。

       PS:注意节点名称,应包含name_scope 和 variable_scope命名空间,并用“/”隔开,如"InceptionV3/Logits/SpatialSqueeze"

二、 pb模型预测

    下面是预测pb模型的代码

def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 # 定义输入的张量名称,对应网络结构的输入张量
 # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
 input_image_tensor = sess.graph.get_tensor_by_name("input:0")
 input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
 input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
 # 定义输出的张量名称
 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
 # 读取测试图片
 im=read_image(image_path,resize_height,resize_width,normalization=True)
 im=im[np.newaxis,:]
 # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
 out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
 input_keep_prob_tensor:1.0,
 input_is_training_tensor:False})
 print("out:{}".format(out))
 score = tf.nn.softmax(out, name='pre')
 class_id = tf.argmax(score, 1)
 print "pre class_id:{}".format(sess.run(class_id))

说明:

1、与ckpt预测不同的是,pb文件已经固化了网络模型结构,因此,即使不知道原训练模型(train)的源码,我们也可以恢复网络图,并进行预测。恢复模型十分简单,只需要从读取的序列化数据中导入网络结构即可:

tf.import_graph_def(output_graph_def, name="")
2、但必须知道原网络模型的输入和输出的节点名称(当然了,传递数据时,是通过输入输出的张量来完成的)。由于InceptionV3模型的输入有三个节点,因此这里需要定义输入的张量名称,它对应网络结构的输入张量:

input_image_tensor = sess.graph.get_tensor_by_name("input:0")
input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
以及输出的张量名称:

output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")

3、预测时,需要feed输入数据:

# 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
# out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
                                            input_keep_prob_tensor:1.0,
                                            input_is_training_tensor:False})

4、其他网络模型预测时,也可以通过修改输入和输出的张量的名称 。

       PS:注意张量的名称,即为:节点名称+“:”+“id号”,如"InceptionV3/Logits/SpatialSqueeze:0"

完整的CKPT 转换成 PB格式和预测的代码如下:

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : pan_jinquan@163.com
 @Date : 2018-08-29 17:46:50
 @info :
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定图片高度
resize_width = 299 # 指定图片宽度
depths = 3
 
def freeze_graph_test(pb_path, image_path):
 '''
 :param pb_path:pb文件的路径
 :param image_path:测试图片的路径
 :return:
 '''
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 
 # 定义输入的张量名称,对应网络结构的输入张量
 # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
 input_image_tensor = sess.graph.get_tensor_by_name("input:0")
 input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
 input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
 # 定义输出的张量名称
 output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
 # 读取测试图片
 im=read_image(image_path,resize_height,resize_width,normalization=True)
 im=im[np.newaxis,:]
 # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
 # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
 out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
 input_keep_prob_tensor:1.0,
 input_is_training_tensor:False})
 print("out:{}".format(out))
 score = tf.nn.softmax(out, name='pre')
 class_id = tf.argmax(score, 1)
 print "pre class_id:{}".format(sess.run(class_id))
 
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in sess.graph.get_operations():
 # print(op.name, op.values())
 
def freeze_graph2(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 graph = tf.get_default_graph() # 获得默认的图
 input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) #恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=input_graph_def,# 等于:sess.graph_def
 output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
 
 with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
 f.write(output_graph_def.SerializeToString()) #序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 # for op in graph.get_operations():
 # print(op.name, op.values())
 
 
if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint='models/model.ckpt-10000'
 # 输出pb模型的路径
 out_pb_path="models/pb/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 测试pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path, image_path=image_path)

三、源码下载和资料推荐

    1、训练方法
     上面的CKPT 转换成 PB格式例子,是我训练GoogleNet InceptionV3模型保存的ckpt转pb文件的例子,训练过程可参考博客:

《使用自己的数据集训练GoogLenet InceptionNet V1 V2 V3模型(TensorFlow)》:https://blog.csdn.net/guyuealian/article/details/81560537

    2、Github地址
Github源码:https://github.com/PanJinquan/tensorflow_models_nets  中的convert_pb.py文件

预训练模型下载地址:http://xiazai.3water.com/202004/yuanma/googlenet_inception_3water.rar

    3、将模型移植Android的方法
     pb文件是可以移植到Android平台运行的,其方法,可参考:

《将tensorflow训练好的模型移植到Android (MNIST手写数字识别)》

参考:

[1] https://3water.com/article/185209.htm

【2】https://3water.com/article/185206.htm

到此这篇关于tensorflow实现将ckpt转pb文件的方法的文章就介绍到这了,更多相关tensorflow ckpt转pb文件内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python压缩和解压缩zip文件
Feb 14 Python
12步入门Python中的decorator装饰器使用方法
Jun 20 Python
TensorFlow实现简单的CNN的方法
Jul 18 Python
详解用python计算阶乘的几种方法
Aug 14 Python
Python字符编码转码之GBK,UTF8互转
Feb 09 Python
pymysql 插入数据 转义处理方式
Mar 02 Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 Python
Python实现GIF图倒放
Jul 16 Python
Python self用法详解
Nov 28 Python
利用Python过滤相似文本的简单方法示例
Feb 03 Python
Python爬虫之爬取某文库文档数据
Apr 21 Python
端午节将至,用Python爬取粽子数据并可视化,看看网友喜欢哪种粽子吧!
Jun 11 Python
jupyter lab文件导出/下载方式
Apr 22 #Python
python模拟实现分发扑克牌
Apr 22 #Python
tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)
Apr 22 #Python
有趣的Python图片制作之如何用QQ好友头像拼接出里昂
Apr 22 #Python
python模拟斗地主发牌
Apr 22 #Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
Apr 22 #Python
python根据完整路径获得盘名/路径名/文件名/文件扩展名的方法
Apr 22 #Python
You might like
判断是否为指定长度内字符串的php函数
2010/02/16 PHP
PHP的Yii框架中移除组件所绑定的行为的方法
2016/03/18 PHP
PHP数据库处理封装类实例
2016/12/24 PHP
php微信支付之公众号支付功能
2018/05/30 PHP
原生PHP实现导出csv格式Excel文件的方法示例【附源码下载】
2019/03/07 PHP
js中判断控件是否存在
2010/08/25 Javascript
javascript中"/"运算符常见错误
2010/10/13 Javascript
用JavaScript获取DOM元素位置和尺寸大小的方法
2013/04/12 Javascript
页面按钮禁用与解除禁用的方法
2014/02/19 Javascript
JavaScript弹出窗口方法汇总
2014/08/12 Javascript
location.hash保存页面状态的技巧
2016/04/28 Javascript
利用JQuery直接调用asp.net后台的简单方法
2016/10/27 Javascript
详解HTTPS 的原理和 NodeJS 的实现
2017/07/04 NodeJs
详解vue指令与$nextTick 操作DOM的不同之处
2018/08/02 Javascript
使用 Node.js 实现图片的动态裁切及算法实例代码详解
2018/09/29 Javascript
详解vue几种主动刷新的方法总结
2019/02/19 Javascript
JavaScrip如果基于url实现图片下载
2020/07/03 Javascript
JS实现小米轮播图
2020/09/21 Javascript
linux下python抓屏实现方法
2015/05/22 Python
利用PyInstaller将python程序.py转为.exe的方法详解
2017/05/03 Python
python使用socket创建tcp服务器和客户端
2018/04/12 Python
Python实现计算圆周率π的值到任意位的方法示例
2018/05/08 Python
python实现远程控制电脑
2019/05/23 Python
Python 旋转打印各种矩形的方法
2019/07/09 Python
Django实现基于类的分页功能
2019/10/31 Python
python实现飞机大战游戏(pygame版)
2020/10/26 Python
使用Django xadmin 实现修改时间选择器为不可输入状态
2020/03/30 Python
PyCharm设置Ipython交互环境和宏快捷键进行数据分析图文详解
2020/04/23 Python
Django多层嵌套ManyToMany字段ORM操作详解
2020/05/19 Python
CSS3 圆角效果
2009/07/15 HTML / CSS
教师校本培训方案
2014/02/26 职场文书
批评与自我批评范文
2014/10/15 职场文书
2014年保险业务员工作总结
2014/12/23 职场文书
毕业纪念册寄语大全
2015/02/26 职场文书
采购部年度工作总结
2015/08/13 职场文书
《雷雨》教学反思
2016/02/20 职场文书