TensorFlow 模型载入方法汇总(小结)


Posted in Python onJune 19, 2018

一、TensorFlow常规模型加载方法

保存模型

tf.train.Saver()类,.save(sess, ckpt文件目录)方法

参数名称 功能说明 默认值
var_list Saver中存储变量集合 全局变量集合
reshape 加载时是否恢复变量形状 True
sharded 是否将变量轮循放在所有设备上 True
max_to_keep 保留最近检查点个数 5
restore_sequentially 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小 True

var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。

如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。

加载模型

当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化

TensorFlow 模型载入方法汇总(小结)

checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:

ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)

TensorFlow 模型载入方法汇总(小结) 

.meta文件保存了当前图结构

.index文件保存了当前参数名

.data文件保存了当前参数值

tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象

ckpt = tf.train.get_checkpoint_state('./model/')

tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量(『TensorFlow』滑动平均)

saver = tf.train.Saver({"v/ExponentialMovingAverage":v})

saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载

saver.restore(sess,'./model/model.ckpt-0')
saver.restore(sess,ckpt.model_checkpoint_path)

1.不加载图结构,只加载参数

由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。

'''
使用原网络保存的模型加载到自己重新定义的图上
可以使用python变量名加载模型,也可以使用节点名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
with tf.Graph().as_default() as g:
 
 x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
 y = Net.inference_1(x, N_CLASS=5, train=False)
 
 with tf.Session() as sess:
  # 程序前面得有 Variable 供 save or restore 才不报错
  # 否则会提示没有可保存的变量
  saver = tf.train.Saver()
 
  ckpt = tf.train.get_checkpoint_state('./model/')
  img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
  img = sess.run(tf.expand_dims(tf.image.resize_images(
   tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
 
  if ckpt and ckpt.model_checkpoint_path:
   print(ckpt.model_checkpoint_path)
   saver.restore(sess,'./model/model.ckpt-0')
   global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
   res = sess.run(y, feed_dict={x: img})
   print(global_step,sess.run(tf.argmax(res,1)))

2.加载图结构和参数

'''
直接使用使用保存好的图
无需加载python定义的结构,直接使用节点名称加载模型
由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错
现阶段不推荐使用,以后如果理解深入了可能会找到使用方法
'''
import AlexNet_train as train
import random
import tensorflow as tf
 
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
 
 
ckpt = tf.train.get_checkpoint_state('./model/')       # 通过检查点文件锁定最新的模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') # 载入图结构,保存在.meta文件中
 
with tf.Session() as sess:
 saver.restore(sess,ckpt.model_checkpoint_path)      # 载入参数,参数保存在两个文件中,不过restore会自己寻找
 
 img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
 img = sess.run(tf.image.resize_images(
  tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
 imgs = []
 for i in range(128):
  imgs.append(img)
 print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))
 
 '''
 img = sess.run(tf.expand_dims(tf.image.resize_images(
  tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
 print(img)
 imgs = []
 for i in range(128):
  imgs.append(img)
 print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),
     feed_dict={'Placeholder:0':img}))

注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。

3.简化版本

# 连同图结构一同加载
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
 saver.restore(sess,ckpt.model_checkpoint_path)
    
# 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
saver = tf.train.Saver()
with tf.Session() as sess:
 ckpt = tf.train.get_checkpoint_state('./model/')
 saver.restore(sess,ckpt.model_checkpoint_path)

二、TensorFlow二进制模型加载方法

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作

# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
 # 二进制读取模型文件
 with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
  # 新建GraphDef文件,用于临时载入模型中的图
  graph_def = tf.GraphDef()
  # GraphDef加载模型中的图
  graph_def.ParseFromString(f.read())
  # 在空白图中加载GraphDef中的图
  tf.import_graph_def(graph_def,name='')
  # 在图中获取张量需要使用graph.get_tensor_by_name加张量名
  # 这里的张量可以直接用于session的run方法求值了
  # 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
  self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
  self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Django模板中的数字自增(详解)
Sep 05 Python
python监控linux内存并写入mongodb(推荐)
Sep 11 Python
Python八大常见排序算法定义、实现及时间消耗效率分析
Apr 27 Python
Python Pywavelet 小波阈值实例
Jan 09 Python
Python除法之传统除法、Floor除法及真除法实例详解
May 23 Python
Python面向对象之Web静态服务器
Sep 03 Python
django实现类似触发器的功能
Nov 15 Python
python实现矩阵和array数组之间的转换
Nov 29 Python
new_zeros() pytorch版本的转换方式
Feb 18 Python
新手入门学习python Numpy基础操作
Mar 02 Python
mac 上配置Pycharm连接远程服务器并实现使用远程服务器Python解释器的方法
Mar 19 Python
Python实现LR1文法的完整实例代码
Oct 25 Python
python3爬虫之设计签名小程序
Jun 19 #Python
Python GUI Tkinter简单实现个性签名设计
Jun 19 #Python
TensorFlow数据输入的方法示例
Jun 19 #Python
深入分析python中整型不会溢出问题
Jun 18 #Python
Python登录注册验证功能实现
Jun 18 #Python
详解python3中zipfile模块用法
Jun 18 #Python
python爬取个性签名的方法
Jun 17 #Python
You might like
一个可查询所有表的“通用”查询分页类
2006/10/09 PHP
php方法调用模式与函数调用模式简例
2011/09/20 PHP
PHP表单提交表单名称含有点号(.)则会被转化为下划线(_)
2011/12/14 PHP
Yii2第三方类库插件Imagine的安装和使用
2017/07/06 PHP
laravel 5异常错误:FatalErrorException in Handler.php line 38的解决
2017/10/12 PHP
在JavaScript并非所有的一切都是对象
2013/04/11 Javascript
基于jquery实现的图片在各种分辨率下未知的容器内上下左右居中
2014/05/11 Javascript
Javascript正则控制文本框只能输入整数或浮点数
2014/09/02 Javascript
PHP配置文件php.ini中打开错误报告的设置方法
2015/01/09 PHP
完美实现仿QQ空间评论回复特效
2015/05/06 Javascript
jQuery解决input超多的表单提交
2015/08/10 Javascript
jQuery鼠标事件总结
2016/10/13 Javascript
JavaScript仿支付宝6位数字密码输入框
2016/12/29 Javascript
JS二叉树的简单实现方法示例
2017/04/05 Javascript
React Native预设占位placeholder的使用
2017/09/28 Javascript
vue实现element-ui对话框可拖拽功能
2018/08/17 Javascript
elementUI 设置input的只读或禁用的方法
2018/10/30 Javascript
微信小程序与webview交互实现支付功能
2019/06/07 Javascript
[03:38]2014DOTA2西雅图国际邀请赛 VG战队巡礼
2014/07/07 DOTA
Python实现模拟时钟代码推荐
2015/11/08 Python
python使用paramiko实现远程拷贝文件的方法
2016/04/18 Python
浅谈python中的实例方法、类方法和静态方法
2017/02/17 Python
pandas.dataframe中根据条件获取元素所在的位置方法(索引)
2018/06/07 Python
python利用thrift服务读取hbase数据的方法
2018/12/27 Python
python爬虫 爬取58同城上所有城市的租房信息详解
2019/07/30 Python
python爬虫 线程池创建并获取文件代码实例
2019/09/28 Python
Python 操作mysql数据库查询之fetchone(), fetchmany(), fetchall()用法示例
2019/10/17 Python
构造方法和其他方法的区别?怎么调用父类的构造方法
2013/09/22 面试题
工商管理专业应届生求职信
2013/11/04 职场文书
大学生创业项目方案
2014/03/08 职场文书
做一个有道德的人活动实施方案
2014/08/23 职场文书
重阳节演讲稿:尊敬帮助老人 弘扬传统美德
2014/09/25 职场文书
升学宴学生答谢词
2015/01/05 职场文书
python文件目录操作之os模块
2021/05/08 Python
pytorch实现ResNet结构的实例代码
2021/05/17 Python
python生成可执行exe控制Microsip自动填写号码并拨打功能
2021/06/21 Python