TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pyqt4教程之messagebox使用示例分享
Mar 07 Python
python中使用urllib2获取http请求状态码的代码例子
Jul 07 Python
Python的几个高级语法概念浅析(lambda表达式闭包装饰器)
May 28 Python
总结网络IO模型与select模型的Python实例讲解
Jun 27 Python
Python set常用操作函数集锦
Nov 15 Python
Python中矩阵创建和矩阵运算方法
Aug 04 Python
Python使用LDAP做用户认证的方法
Jun 20 Python
opencv导入头文件时报错#include的解决方法
Jul 31 Python
用Python画小女孩放风筝的示例
Nov 23 Python
Pycharm配置PyQt5环境的教程
Apr 02 Python
python实现移动木板小游戏
Oct 09 Python
详解numpy1.19.4与python3.9版本冲突解决
Dec 15 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
本地机apache配置基于域名的虚拟主机详解
2013/08/10 PHP
php使用cookie实现记住登录状态
2015/04/27 PHP
thinkPHP中验证码的简单使用方法
2015/12/26 PHP
基于laravel制作APP接口(API)
2016/03/15 PHP
PHP PDOStatement::execute讲解
2019/01/31 PHP
List the Codec Files on a Computer
2007/06/18 Javascript
js获取url中的参数且参数为中文时通过js解码
2014/03/19 Javascript
JavaScript中的操作符==与===介绍
2014/12/31 Javascript
JavaScript显示表单内元素数量的方法
2015/04/02 Javascript
微信jssdk在iframe页面失效问题的解决措施
2016/03/03 Javascript
基于Javascript实现二级联动菜单效果
2016/03/04 Javascript
JS简单获取及显示当前时间的方法
2016/08/03 Javascript
javaScript语法总结
2016/11/25 Javascript
JavaScript函数节流的两种写法
2017/04/07 Javascript
Vue 2.0的数据依赖实现原理代码简析
2017/07/10 Javascript
使用Electron构建React+Webpack桌面应用的方法
2017/12/15 Javascript
jQuery实现的自定义轮播图功能详解
2018/12/28 jQuery
Nodejs中的require函数的具体使用方法
2019/04/02 NodeJs
通过说明与示例了解js五种设计模式
2019/06/17 Javascript
JavaScript 类的封装操作示例详解
2020/05/16 Javascript
vue使用element-ui实现表单验证
2020/12/13 Vue.js
python使用正则搜索字符串或文件中的浮点数代码实例
2014/07/11 Python
Python中的迭代器漫谈
2015/02/03 Python
详细解读tornado协程(coroutine)原理
2018/01/15 Python
对pandas中to_dict的用法详解
2018/06/05 Python
使用Python在Windows下获取USB PID&VID的方法
2019/07/02 Python
Python内建序列通用操作6种实现方法
2020/03/26 Python
python dict乱码如何解决
2020/06/07 Python
Python调用ffmpeg开源视频处理库,批量处理视频
2020/11/16 Python
音乐学院硕士生的自我评价分享
2013/11/01 职场文书
烹饪自我鉴定
2014/03/01 职场文书
超市周年庆活动方案
2014/08/16 职场文书
反四风个人对照检查材料
2014/09/26 职场文书
就业协议书盖章的注意事项
2014/09/28 职场文书
python基础学习之生成器与文件系统知识总结
2021/05/25 Python
使用 Apache 反向代理的设置技巧
2022/01/18 Servers