TensorFlow模型保存/载入的两种方法


Posted in Python onMarch 08, 2018

TensorFlow 模型保存/载入

我们在上线使用一个算法模型的时候,首先必须将已经训练好的模型保存下来。tensorflow保存模型的方式与sklearn不太一样,sklearn很直接,一个sklearn.externals.joblib的dump与load方法就可以保存与载入使用。而tensorflow由于有graph, operation 这些概念,保存与载入模型稍显麻烦。

一、基本方法

网上搜索tensorflow模型保存,搜到的大多是基本的方法。即

保存

  • 定义变量
  • 使用saver.save()方法保存

载入

  • 定义变量
  • 使用saver.restore()方法载入

保存 代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable([[1,1,1],[2,2,2]],dtype = tf.float32,name='w') 
b = tf.Variable([[0,1,2]],dtype = tf.float32,name='b') 

init = tf.initialize_all_variables() 
saver = tf.train.Saver() 
with tf.Session() as sess: 
  sess.run(init) 
  save_path = saver.save(sess,"save/model.ckpt")

载入代码如下

import tensorflow as tf 
import numpy as np 

W = tf.Variable(tf.truncated_normal(shape=(2,3)),dtype = tf.float32,name='w') 
b = tf.Variable(tf.truncated_normal(shape=(1,3)),dtype = tf.float32,name='b') 

saver = tf.train.Saver() 
with tf.Session() as sess: 
  saver.restore(sess,"save/model.ckpt")

这种方法不方便的在于,在使用模型的时候,必须把模型的结构重新定义一遍,然后载入对应名字的变量的值。但是很多时候我们都更希望能够读取一个文件然后就直接使用模型,而不是还要把模型重新定义一遍。所以就需要使用另一种方法。

二、不需重新定义网络结构的方法

tf.train.import_meta_graph

import_meta_graph(
 meta_graph_or_file,
 clear_devices=False,
 import_scope=None,
 **kwargs
)

这个方法可以从文件中将保存的graph的所有节点加载到当前的default graph中,并返回一个saver。也就是说,我们在保存的时候,除了将变量的值保存下来,其实还有将对应graph中的各种节点保存下来,所以模型的结构也同样被保存下来了。

比如我们想要保存计算最后预测结果的y,则应该在训练阶段将它添加到collection中。具体代码如下

保存

### 定义模型
input_x = tf.placeholder(tf.float32, shape=(None, in_dim), name='input_x')
input_y = tf.placeholder(tf.float32, shape=(None, out_dim), name='input_y')

w1 = tf.Variable(tf.truncated_normal([in_dim, h1_dim], stddev=0.1), name='w1')
b1 = tf.Variable(tf.zeros([h1_dim]), name='b1')
w2 = tf.Variable(tf.zeros([h1_dim, out_dim]), name='w2')
b2 = tf.Variable(tf.zeros([out_dim]), name='b2')
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
hidden1 = tf.nn.relu(tf.matmul(self.input_x, w1) + b1)
hidden1_drop = tf.nn.dropout(hidden1, self.keep_prob)
### 定义预测目标
y = tf.nn.softmax(tf.matmul(hidden1_drop, w2) + b2)
# 创建saver
saver = tf.train.Saver(...variables...)
# 假如需要保存y,以便在预测时使用
tf.add_to_collection('pred_network', y)
sess = tf.Session()
for step in xrange(1000000):
 sess.run(train_op)
 if step % 1000 == 0:
  # 保存checkpoint, 同时也默认导出一个meta_graph
  # graph名为'my-model-{global_step}.meta'.
  saver.save(sess, 'my-model', global_step=step)

载入

with tf.Session() as sess:
 new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
 new_saver.restore(sess, 'my-save-dir/my-model-10000')
 # tf.get_collection() 返回一个list. 但是这里只要第一个参数即可
 y = tf.get_collection('pred_network')[0]

 graph = tf.get_default_graph()

 # 因为y中有placeholder,所以sess.run(y)的时候还需要用实际待预测的样本以及相应的参数来填充这些placeholder,而这些需要通过graph的get_operation_by_name方法来获取。
 input_x = graph.get_operation_by_name('input_x').outputs[0]
 keep_prob = graph.get_operation_by_name('keep_prob').outputs[0]

 # 使用y进行预测 
 sess.run(y, feed_dict={input_x:...., keep_prob:1.0})

这里有两点需要注意的:

一、saver.restore()时填的文件名,因为在saver.save的时候,每个checkpoint会保存三个文件,如
my-model-10000.meta, my-model-10000.index, my-model-10000.data-00000-of-00001
import_meta_graph时填的就是meta文件名,我们知道权值都保存在my-model-10000.data-00000-of-00001这个文件中,但是如果在restore方法中填这个文件名,就会报错,应该填的是前缀,这个前缀可以使用tf.train.latest_checkpoint(checkpoint_dir)这个方法获取。

二、模型的y中有用到placeholder,在sess.run()的时候肯定要feed对应的数据,因此还要根据具体placeholder的名字,从graph中使用get_operation_by_name方法获取。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python获取CPU和内存信息的思路与实现(linux系统)
Jan 03 Python
Python黑帽编程 3.4 跨越VLAN详解
Sep 28 Python
python中 logging的使用详解
Oct 25 Python
pandas or sql计算前后两行数据间的增值方法
Apr 20 Python
Django框架验证码用法实例分析
May 10 Python
python数据预处理之数据标准化的几种处理方式
Jul 17 Python
Python 3.8正式发布重要新功能一览
Oct 17 Python
详解centos7+django+python3+mysql+阿里云部署项目全流程
Nov 15 Python
python数值基础知识浅析
Nov 19 Python
Python常用库大全及简要说明
Jan 17 Python
tf.concat中axis的含义与使用详解
Feb 07 Python
解决导入django_filters不成功问题No module named 'django_filter'
Jul 15 Python
python2.7 json 转换日期的处理的示例
Mar 07 #Python
教你用Python创建微信聊天机器人
Mar 31 #Python
为什么入门大数据选择Python而不是Java?
Mar 07 #Python
详解Python中如何写控制台进度条的整理
Mar 07 #Python
python爬虫爬取网页表格数据
Mar 07 #Python
python使用mysql的两种使用方式
Mar 07 #Python
python表格存取的方法
Mar 07 #Python
You might like
什么是调频(FM)、调幅(AM)、短波(SW)、长波(LW)
2021/03/01 无线电
ubuntu 编译安装php 5.3.3+memcache的方法
2010/08/05 PHP
php遍历目录输出目录及其下的所有文件示例
2014/01/27 PHP
PHP的password_hash()使用实例
2014/03/17 PHP
PHP goto语句用法实例
2019/08/06 PHP
Thinkphp 3.2框架使用Redis的方法详解
2019/10/24 PHP
PHP利用curl发送HTTP请求的实例代码
2020/07/09 PHP
JS实现清除指定cookies的方法
2014/09/20 Javascript
javascript判断移动端访问设备并解析对应CSS的方法
2015/02/05 Javascript
js实现滚动条滚动到页面底部继续加载
2015/12/19 Javascript
javascript html5移动端轻松实现文件上传
2020/03/27 Javascript
jQuery实现响应鼠标事件的图片透明效果【附demo源码下载】
2016/06/16 Javascript
EasyUI创建对话框的两种方式
2016/08/23 Javascript
react redux入门示例
2018/04/19 Javascript
Node.js中的cluster模块深入解读
2018/06/11 Javascript
vue2.0使用v-for循环制作多级嵌套菜单栏
2018/06/25 Javascript
关于AngularJS中ng-repeat不更新视图的解决方法
2018/09/30 Javascript
javascript中函数的写法实例代码详解
2018/10/28 Javascript
基于JS实现前端压缩上传图片的实例代码
2019/05/14 Javascript
[51:07]VGJ.S vs Pain 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
Python查找函数f(x)=0根的解决方法
2015/05/07 Python
selenium+python 去除启动的黑色cmd窗口方法
2018/05/22 Python
Keras:Unet网络实现多类语义分割方式
2020/06/11 Python
Max&Co官网:意大利年轻女性时尚品牌
2017/05/16 全球购物
德国综合购物网站:OTTO
2018/11/13 全球购物
新西兰优惠网站:Treat Me
2019/07/04 全球购物
圣诞树世界:Christmas Tree World
2019/12/10 全球购物
小学教师读书活动总结
2014/07/08 职场文书
关心下一代工作先进事迹
2014/08/15 职场文书
恰同学少年观后感
2015/06/08 职场文书
2015小学音乐教师个人工作总结
2015/07/21 职场文书
穷人该怎么创业?谨记以下几点
2019/07/11 职场文书
学会掌握自己命运的十条黄金法则:
2019/08/08 职场文书
MySQL基础快速入门知识总结(附思维导图)
2021/09/25 MySQL
Java存储没有重复元素的数组
2022/04/29 Java/Android
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
2022/05/04 Python