TensorFlow模型保存和提取的方法


Posted in Python onMarch 08, 2018

一、TensorFlow模型保存和提取方法

1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt") ,实际在这个文件目录下会生成4个人文件:

TensorFlow模型保存和提取的方法

checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。

2. 加载这个已保存的TensorFlow模型的方法是saver.restore(sess,"./Model/model.ckpt") ,加载模型的代码中也要定义TensorFlow计算图上的所有运算并声明一个tf.train.Saver类,不同的是加载模型时不需要进行变量的初始化,而是将变量的取值通过保存的模型加载进来,注意加载路径的写法。若不希望重复定义计算图上的运算,可直接加载已经持久化的图,saver =tf.train.import_meta_graph("Model/model.ckpt.meta")

3.tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名},saver = tf.train.Saver({"v1":u1, "v2": u2})即原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中。

4. 上一条做的目的之一就是方便使用变量的滑动平均值。如果在加载模型时直接将影子变量映射到变量自身,则在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。载入时,声明Saver类对象时通过一个字典将滑动平均值直接加载到新的变量中,saver = tf.train.Saver({"v/ExponentialMovingAverage": v}),另通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典。

此外,通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中。

二、TensorFlow程序实现

# 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行 
# 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确 
 
 
# Part1: 通过tf.train.Saver类实现保存和载入神经网络模型 
 
# 执行本段程序时注意当前的工作路径 
import tensorflow as tf 
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2 
 
saver = tf.train.Saver() 
 
with tf.Session() as sess: 
 sess.run(tf.global_variables_initializer()) 
 saver.save(sess, "Model/model.ckpt") 
 
 
# Part2: 加载TensorFlow模型的方法 
 
import tensorflow as tf 
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2 
 
saver = tf.train.Saver() 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./" 
 print(sess.run(result)) # [ 3.] 
 
 
# Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图 
 
import tensorflow as tf 
 
saver = tf.train.import_meta_graph("Model/model.ckpt.meta") 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model.ckpt") # 注意路径写法 
 print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.] 
 
 
# Part4: tf.train.Saver类也支持在保存和加载时给变量重命名 
 
import tensorflow as tf 
 
# 声明的变量名称name与已保存的模型中的变量名称name不一致 
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1") 
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2") 
result = u1 + u2 
 
# 若直接生命Saver类对象,会报错变量找不到 
# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名} 
# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中 
saver = tf.train.Saver({"v1": u1, "v2": u2}) 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model.ckpt") 
 print(sess.run(result)) # [ 3.] 
 
 
# Part5: 保存滑动平均模型 
 
import tensorflow as tf 
 
v = tf.Variable(0, dtype=tf.float32, name="v") 
for variables in tf.global_variables(): 
 print(variables.name) # v:0 
 
ema = tf.train.ExponentialMovingAverage(0.99) 
maintain_averages_op = ema.apply(tf.global_variables()) 
for variables in tf.global_variables(): 
 print(variables.name) # v:0 
       # v/ExponentialMovingAverage:0 
 
saver = tf.train.Saver() 
 
with tf.Session() as sess: 
 sess.run(tf.global_variables_initializer()) 
 sess.run(tf.assign(v, 10)) 
 sess.run(maintain_averages_op) 
 saver.save(sess, "Model/model_ema.ckpt") 
 print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905] 
 
 
# Part6: 通过变量重命名直接读取变量的滑动平均值 
 
import tensorflow as tf 
 
v = tf.Variable(0, dtype=tf.float32, name="v") 
saver = tf.train.Saver({"v/ExponentialMovingAverage": v}) 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model_ema.ckpt") 
 print(sess.run(v)) # 0.0999999 
 
 
# Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典 
 
import tensorflow as tf 
 
v = tf.Variable(0, dtype=tf.float32, name="v") 
# 注意此处的变量名称name一定要与已保存的变量名称一致 
ema = tf.train.ExponentialMovingAverage(0.99) 
print(ema.variables_to_restore()) 
# {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} 
# 此处的v取自上面变量v的名称name="v" 
 
saver = tf.train.Saver(ema.variables_to_restore()) 
 
with tf.Session() as sess: 
 saver.restore(sess, "./Model/model_ema.ckpt") 
 print(sess.run(v)) # 0.0999999 
 
 
# Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中 
 
import tensorflow as tf 
from tensorflow.python.framework import graph_util 
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2 
 
with tf.Session() as sess: 
 sess.run(tf.global_variables_initializer()) 
 # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分 
 graph_def = tf.get_default_graph().as_graph_def() 
 output_graph_def = graph_util.convert_variables_to_constants(sess, 
              graph_def, ['add']) 
 
 with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f: 
  f.write(output_graph_def.SerializeToString()) 
 
 
# Part9: 载入包含变量及其取值的模型 
 
import tensorflow as tf 
from tensorflow.python.platform import gfile 
 
with tf.Session() as sess: 
 model_filename = "Model/combined_model.pb" 
 with gfile.FastGFile(model_filename, 'rb') as f: 
  graph_def = tf.GraphDef() 
  graph_def.ParseFromString(f.read()) 
 
 result = tf.import_graph_def(graph_def, return_elements=["add:0"]) 
 print(sess.run(result)) # [array([ 3.], dtype=float32)]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python查询Mysql时返回字典结构的代码
Jun 18 Python
Python自动化测试工具Splinter简介和使用实例
May 13 Python
请不要重复犯我在学习Python和Linux系统上的错误
Dec 12 Python
Python之os操作方法(详解)
Jun 15 Python
Python 查找list中的某个元素的所有的下标方法
Jun 27 Python
python使用pygame框架实现推箱子游戏
Nov 20 Python
python selenium 弹出框处理的实现
Feb 26 Python
tensorflow中tf.slice和tf.gather切片函数的使用
Jan 19 Python
Python tkinter模版代码实例
Feb 05 Python
Python终端输出彩色字符方法详解
Feb 11 Python
信号生成及DFT的python实现方式
Feb 25 Python
简述python&amp;pytorch 随机种子的实现
Oct 07 Python
火车票抢票python代码公开揭秘!
Mar 08 #Python
Python实现定时备份mysql数据库并把备份数据库邮件发送
Mar 08 #Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 #Python
TensorFlow模型保存/载入的两种方法
Mar 08 #Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
You might like
PHP通过COM使用ADODB的简单例子
2006/12/31 PHP
php使用pclzip类实现文件压缩的方法(附pclzip类下载地址)
2016/04/30 PHP
动态加载图片路径 保持JavaScript控件的相对独立性
2010/09/03 Javascript
基于jQuery的Tab选项框效果代码(插件)
2011/03/01 Javascript
推荐10个超棒的jQuery工具提示插件
2011/10/11 Javascript
自己用jQuery写了一个图片的马赛克消失效果
2014/05/04 Javascript
5款JavaScript代码压缩工具推荐
2014/07/07 Javascript
jQuery scrollFix滚动定位插件
2015/04/01 Javascript
用JavaScript显示浏览器客户端信息的超相近教程
2015/06/18 Javascript
jquery无限级联下拉菜单简单实例演示
2015/11/23 Javascript
Html5+jQuery+CSS制作相册小记录
2016/12/30 Javascript
vuejs父子组件通信的问题
2017/01/11 Javascript
jQuery tip提示插件(实例分享)
2017/04/28 jQuery
VueJs组件prop验证简单介绍
2017/09/12 Javascript
Vue.js 点击按钮显示/隐藏内容的实例代码
2018/02/08 Javascript
JS实现将链接生成二维码并转为图片的方法
2018/03/17 Javascript
VUE实现可随意拖动的弹窗组件
2018/09/25 Javascript
Vue自定义组件的四种方式示例详解
2020/02/28 Javascript
python设置windows桌面壁纸的实现代码
2013/01/28 Python
使用Python的PEAK来适配协议的教程
2015/04/14 Python
Python函数可变参数定义及其参数传递方式实例详解
2015/05/25 Python
python 读取竖线分隔符的文本方法
2018/12/20 Python
python将字典列表导出为Excel文件的方法
2019/09/02 Python
python词云库wordcloud的使用方法与实例详解
2020/02/17 Python
DRF使用simple JWT身份验证的实现
2021/01/14 Python
Ray-Ban雷朋美国官网:全球领先的太阳眼镜品牌
2016/07/20 全球购物
美国著名的女性内衣零售商:Frederick’s of Hollywood
2018/02/24 全球购物
英国领先的电视购物零售商:Ideal World
2019/03/18 全球购物
路政管理专业个人自荐信范文
2013/11/30 职场文书
意向书范文
2014/03/31 职场文书
代办委托书怎样写
2014/04/08 职场文书
道路交通事故人身损害赔偿协议书
2014/11/19 职场文书
《火烧云》教学反思
2016/02/23 职场文书
PHP命令行与定时任务
2021/04/01 PHP
CentOS 7安装mysql5.7使用XtraBackUp备份工具命令详解
2022/04/12 MySQL
利用Redis实现点赞功能的示例代码
2022/06/28 Redis