TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python encode和decode的妙用
Sep 02 Python
python中类的一些方法分析
Sep 25 Python
python实现向ppt文件里插入新幻灯片页面的方法
Apr 28 Python
进一步探究Python中的正则表达式
Apr 28 Python
用virtualenv建立多个Python独立虚拟开发环境
Jul 06 Python
浅谈python中统计计数的几种方法和Counter详解
Nov 07 Python
python将四元数变换为旋转矩阵的实例
Dec 04 Python
PyCharm刷新项目(文件)目录的实现
Feb 14 Python
python网络编程socket实现服务端、客户端操作详解
Mar 24 Python
python判断是空的实例分享
Jul 06 Python
python使用re模块爬取豆瓣Top250电影
Oct 20 Python
Python+logging输出到屏幕将log日志写入文件
Nov 11 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
WordPress中"无法将上传的文件移动至"错误的解决方法
2015/07/01 PHP
无需数据库在线投票调查php代码
2016/07/20 PHP
Symfony查询方法实例小结
2017/06/28 PHP
PHP实现微信提现功能
2018/09/30 PHP
Laravel-添加后台模板AdminLte的实现方法
2019/10/08 PHP
关于laravel5.5的定时任务详解(demo)
2019/10/23 PHP
js 颜色选择器(兼容firefox)
2009/03/05 Javascript
extjs表格文本启用选择复制功能具体实现
2013/10/11 Javascript
JS简单实现登陆验证附效果图
2013/11/19 Javascript
js计算两个时间之间天数差的实例代码
2013/11/19 Javascript
验证码在IE中不刷新而谷歌等浏览器正常的解决方案
2014/03/18 Javascript
探究Javascript模板引擎mustache.js使用方法
2016/01/26 Javascript
基于jquery实现轮播焦点图插件
2016/03/31 Javascript
AngularJS中的API(接口)简单实现
2016/07/28 Javascript
bootstrap模态框消失问题的解决方法
2016/12/02 Javascript
JavaScript截屏功能的实现代码
2017/07/28 Javascript
es6中的解构赋值、扩展运算符和rest参数使用详解
2017/09/28 Javascript
vue2.0项目实现路由跳转的方法详解
2018/06/21 Javascript
vue-router之解决addRoutes使用遇到的坑
2020/07/19 Javascript
Python闭包的两个注意事项(推荐)
2017/03/20 Python
Python使用装饰器模拟用户登陆验证功能示例
2018/08/24 Python
python机器学习之KNN分类算法
2018/08/29 Python
python多任务及返回值的处理方法
2019/01/22 Python
python3 下载网络图片代码实例
2019/08/27 Python
使用HTML5和CSS3表单验证功能
2017/05/05 HTML / CSS
HTML5重塑Web世界它将如何改变互联网
2012/12/17 HTML / CSS
幼师自我鉴定范文
2013/10/01 职场文书
高中生职业规划范文
2014/03/09 职场文书
学雷锋宣传标语
2014/06/25 职场文书
机关干部个人对照检查材料思想汇报
2014/09/28 职场文书
民间借贷纠纷案件代理词
2015/05/26 职场文书
讲座开场白台词和结束语
2015/05/29 职场文书
加薪申请书应该这样写!
2019/07/04 职场文书
python随机打印成绩排名表
2021/06/23 Python
Mysql binlog日志文件过大的解决
2021/10/05 MySQL
mysql查看表结构的三种方法总结
2022/07/07 MySQL