tensorflow的ckpt及pb模型持久化方式及转化详解


Posted in Python onFebruary 12, 2020

使用tensorflow训练模型的时候,模型持久化对我们来说非常重要。

如果我们的模型比较复杂,需要的数据比较多,那么在模型的训练时间会耗时很长。如果在训练过程中出现了模型不可预期的错误,导致训练意外终止,那么我们将会前功尽弃。为了解决这一问题,我们可以使用模型持久化(保存为ckpt文件格式)来保存我们在训练过程中的临时数据。、

如果我们训练出的模型需要提供给用户做离线预测,那么我们只需要完成前向传播过程。这个时候我们就可以使用模型持久化(保存为pb文件格式)来只保存前向传播过程中的变量并将变量固定下来,这时候用户只需要提供一个输入即可得到前向传播的预测结果。

ckpt和pb持久化方式的区别在于ckpt文件将模型结构与模型权重分离保存,便于训练过程;pb文件则是graph_def的序列化文件,便于发布和离线预测。官方提供freeze_grpah.py脚本来将ckpt文件转为pb文件。

CKPT模型持久化

首先定义前向传播过程;

声明并得到一个Saver;

使用Saver.save()保存模型;

# coding=UTF-8 支持中文编码格式
import tensorflow as tf
import shutil
import os.path
 
MODEL_DIR = "/home/zheng/PycharmProjects/ckptLoad/Models/"
MODEL_NAME = "model.ckpt"
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #输入占位符,并指定名字,后续模型读取可能会用的
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions") #输出节点名字,后续模型读取会用到,比50大返回true,否则返回false
 
init = tf.global_variables_initializer()
saver = tf.train.Saver() #声明saver用于保存模型
 
with tf.Session() as sess:
 sess.run(init)
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}) #输入一个数据测试一下
 saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME)) #模型保存
 print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node)) #得到当前图有几个操作节点

predictions : [ 101.]
28 ops in the final graph.

注:代码含义请参考注释,需要注意的是可以自定义模型保存的路径

ckpt模型持久化使用起来非常简单,只需要我们声明一个tf.train.Saver,然后调用save()函数,将会话模型保存到指定的目录。执行代码结果,会在我们指定模型目录下出现4个文件

tensorflow的ckpt及pb模型持久化方式及转化详解

checkpoint : 记录目录下所有模型文件列表
ckpt.data : 保存模型中每个变量的取值
ckpt.meta : 保存整个计算图的结构

ckpt模型加载

# -*- coding: utf-8 -*-)
import tensorflow as tf
from numpy.random import RandomState
 
# 定义训练数据batch的大小
batch_size = 8
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder") #输入占位符,并指定名字,后续模型读取可能会用的
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions") #输出节点名字,后续模型读取会用到,比50大返回true,否则返回false
 
#saver=tf.train.Saver()
# creare a session,创建一个会话来运行TensorFlow程序
with tf.Session() as sess:
 
 saver = tf.train.import_meta_graph('/home/zheng/Models/model/model.meta')
 saver.restore(sess, tf.train.latest_checkpoint('/home/zheng/Models/model'))
 #saver.restore(sess, tf.train.latest_checkpoint('/home/zheng/Models/model'))
 # 初始化变量
 sess.run(tf.global_variables_initializer())
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})

代码结果,可以看到运行结果一样

predictions : [ 101.]

PB模型持久化

定义运算过程

通过 get_default_graph().as_graph_def() 得到当前图的计算节点信息

通过 graph_util.convert_variables_to_constants 将相关节点的values固定

通过 tf.gfile.GFile 进行模型持久化

# coding=UTF-8
import tensorflow as tf
import shutil
import os.path
from tensorflow.python.framework import graph_util
 
MODEL_DIR = "/home/zheng/PycharmProjects/pbLoad/Models/"
MODEL_NAME = "model"
 
 
#output_graph = "model/pb/add_model.pb"
 
#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式
input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")
W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")
B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")
_y = (input_holder * W1) + B1
predictions = tf.add(_y, 50, name="predictions")
init = tf.global_variables_initializer()
 
with tf.Session() as sess:
 sess.run(init)
 print "predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]})
 graph_def = tf.get_default_graph().as_graph_def() #得到当前的图的 GraphDef 部分,
              #通过这个部分就可以完成重输入层到
              #输出层的计算过程
 
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
  sess,
  graph_def,
  ["predictions"] #需要保存节点的名字
 )
 with tf.gfile.GFile(os.path.join(MODEL_DIR,MODEL_NAME), "wb") as f: # 保存模型
  f.write(output_graph_def.SerializeToString()) # 序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node))
 print (predictions)
 
# for op in tf.get_default_graph().get_operations(): 打印模型节点信息
#  print (op.name)

结果输出

predictions : [ 101.]
Converted 2 variables to const ops.
9 ops in the final graph.
Tensor("predictions:0", shape=(1,), dtype=float32)

tensorflow的ckpt及pb模型持久化方式及转化详解

并在指定目录下生成pb文件模型,保存了从输入层到输出层这个计算过程的计算图和相关变量的值,我们得到这个模型后传入一个输入,既可以得到一个预估的输出值

pb模型文件加载

# -*- coding: utf-8 -*-)
from tensorflow.python.platform import gfile
import tensorflow as tf
from numpy.random import RandomState
 
sess = tf.Session()
with gfile.FastGFile('./Models/model', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') # 导入计算图
 
# 需要有一个初始化的过程
sess.run(tf.global_variables_initializer())
# 需要先复原变量
sess.run('W1:0')
sess.run('B1:0')
# 输入
input_x = sess.graph.get_tensor_by_name('input_holder:0')
#input_y = sess.graph.get_tensor_by_name('y-input:0')
op = sess.graph.get_tensor_by_name('predictions:0')
ret = sess.run(op, feed_dict={input_x:[10]})
print(ret)

输出结果

[ 101.]

我们可以看到结果一致。

ckpt格式转pb格式

通过传入 CKPT 模型的路径得到模型的图和变量数据
通过 import_meta_graph 导入模型中的图
通过 saver.restore 从模型中恢复图中各个变量的数据
通过 graph_util.convert_variables_to_constants 将模型持久化

# coding=UTF-8
import tensorflow as tf
import os.path
import argparse
from tensorflow.python.framework import graph_util
 
MODEL_DIR = "/home/zheng/PycharmProjects/ckptToPb/model/"
MODEL_NAME = "frozen_model"
 
def freeze_graph(model_folder):
 checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
 input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
 output_graph = os.path.join(MODEL_DIR, MODEL_NAME) #PB模型保存路径
 
 output_node_names = "predictions" #原模型输出操作节点的名字
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) #得到图、clear_devices :Whether or not to clear the device field for an `Operation` or `Tensor` during import.
 
 graph = tf.get_default_graph() #获得默认的图
 input_graph_def = graph.as_graph_def() #返回一个序列化的图代表当前的图
 
 with tf.Session() as sess:
  saver.restore(sess, input_checkpoint) #恢复图并得到数据
 
  print "predictions : ", sess.run("predictions:0", feed_dict={"input_holder:0": [10.0]}) # 测试读出来的模型是否正确,注意这里传入的是输出 和输入 节点的 tensor的名字,不是操作节点的名字
 
  output_graph_def = graph_util.convert_variables_to_constants( #模型持久化,将变量值固定
   sess,
   input_graph_def,
   output_node_names.split(",") #如果有多个输出节点,以逗号隔开
  )
  with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
   f.write(output_graph_def.SerializeToString()) #序列化输出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
 
 
if __name__ == '__main__':
 #parser = argparse.ArgumentParser()
 #parser.add_argument("model_folder", type=str, help="input ckpt model dir") #命令行解析,help是提示符,type是输入的类型,
 # 这里运行程序时需要带上模型ckpt的路径,不然会报 error: too few arguments
 #aggs = parser.parse_args()
 #freeze_graph(aggs.model_folder)
 freeze_graph("/home/zheng/PycharmProjects/ckptLoad/Models/") #模型目录

注意改变ckpt模型目录及pb文件保存目录 。

tensorflow的ckpt及pb模型持久化方式及转化详解

运行结果为

predictions : [ 101.]
Converted 2 variables to const ops.
9 ops in the final graph.

总结:cpkt文件格式将模型保存为4个文件,pb文件格式为一个。ckpt模型持久化方式将图结构与权重参数分开保存,多了模型更多的细节,适合模型训练阶段;而pb持久化方式完成了从输入到输出的前向传播,完成了端到端的形式,更是个离线使用。

以上这篇tensorflow的ckpt及pb模型持久化方式及转化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python抓取京东图书评论数据
Aug 31 Python
谈谈如何手动释放Python的内存
Dec 17 Python
一个月入门Python爬虫学习,轻松爬取大规模数据
Jan 03 Python
python时间日期函数与利用pandas进行时间序列处理详解
Mar 13 Python
python实现学生信息管理系统
Apr 05 Python
Python调用服务接口的实例
Jan 03 Python
python爬取内容存入Excel实例
Feb 20 Python
使用Windows批处理和WMI设置Python的环境变量方法
Aug 14 Python
Django之使用内置函数和celery发邮件的方法示例
Sep 16 Python
python爬虫模块URL管理器模块用法解析
Feb 03 Python
python rolling regression. 使用 Python 实现滚动回归操作
Jun 08 Python
python 通过exifread读取照片信息
Dec 24 Python
关于Tensorflow 模型持久化详解
Feb 12 #Python
Python qrcode 生成一个二维码的实例详解
Feb 12 #Python
python标准库sys和OS的函数使用方法与实例详解
Feb 12 #Python
python标准库os库的函数介绍
Feb 12 #Python
Tensorflow 1.0之后模型文件、权重数值的读取方式
Feb 12 #Python
Python django框架开发发布会签到系统(web开发)
Feb 12 #Python
Python计算公交发车时间的完整代码
Feb 12 #Python
You might like
ThinkPHP的RBAC(基于角色权限控制)深入解析
2013/06/17 PHP
解析如何通过PHP函数获取当前运行的环境 来进行判断执行逻辑(小技巧)
2013/06/25 PHP
php脚本守护进程原理与实现方法详解
2017/07/20 PHP
Laravel5.5以下版本中如何自定义日志行为详解
2018/08/01 PHP
Apache站点配置SSL强制跳转443
2021/03/09 Servers
JQuery入门—编写一个简单的JQuery应用案例
2013/01/03 Javascript
.net,js捕捉文本框回车键事件的小例子(兼容多浏览器)
2013/03/11 Javascript
js如何调用qq互联api实现第三方登录
2014/03/28 Javascript
浅谈javascript对象模型和function对象
2014/12/26 Javascript
Express与NodeJs创建服务器的两种方法
2017/02/06 NodeJs
微信小程序 检查接口状态实例详解
2017/06/23 Javascript
JS库中的Particles.js在vue上的运用案例分析
2017/09/13 Javascript
Echarts之悬浮框中的数据排序问题
2018/11/08 Javascript
[04:13]2014DOTA2国际邀请赛 专访DC目前形势不容乐观
2014/07/12 DOTA
跟老齐学Python之使用Python查询更新数据库
2014/11/25 Python
在Python中实现贪婪排名算法的教程
2015/04/17 Python
pandas系列之DataFrame 行列数据筛选实例
2018/04/12 Python
python 获取一个值在某个区间的指定倍数的值方法
2018/11/12 Python
Python绘图实现显示中文
2019/12/04 Python
如何将PySpark导入Python的放实现(2种)
2020/04/26 Python
Python Dict找出value大于某值或key大于某值的所有项方式
2020/06/05 Python
使用Python爬取Json数据的示例代码
2020/12/07 Python
10分钟入门CSS3 Animation
2018/12/25 HTML / CSS
Android本地应用打开方法——通过html5写连接
2016/03/11 HTML / CSS
Anthropologie英国:美国家喻户晓的休闲服装和家居产品品牌
2018/12/05 全球购物
古驰英国官网:GUCCI英国
2020/03/07 全球购物
武汉某公司的C#笔试题面试题
2015/12/25 面试题
办公室文书岗位职责
2013/12/16 职场文书
入学申请自荐信范文
2014/02/26 职场文书
遗体告别仪式主持词
2014/03/20 职场文书
小学生运动会通讯稿
2014/09/23 职场文书
信用卡工资证明范本
2014/10/17 职场文书
教师纪律作风整顿心得体会
2016/01/23 职场文书
python基础入门之字典和集合
2021/06/13 Python
vue elementUI表格控制对应列
2022/04/13 Vue.js
索尼ICF-36收音机评测
2022/04/30 无线电