python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现问号表达式(?)的方法
Nov 27 Python
python处理圆角图片、圆形图片的例子
Apr 25 Python
Python实现Const详解
Jan 27 Python
Python中的默认参数详解
Jun 24 Python
Python实现好友全头像的拼接实例(推荐)
Jun 24 Python
pandas object格式转float64格式的方法
Apr 10 Python
详解Python 爬取13个旅游城市,告诉你五一大家最爱去哪玩?
May 07 Python
Django 后台获取文件列表 InMemoryUploadedFile的例子
Aug 07 Python
python使用opencv实现马赛克效果示例
Sep 28 Python
python实现树的深度优先遍历与广度优先遍历详解
Oct 26 Python
Python3并发写文件与Python对比
Nov 20 Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
PHP读写文件的方法(生成HTML)
2006/11/27 PHP
在Laravel框架里实现发送邮件实例(邮箱验证)
2016/05/20 PHP
yii2-GridView在开发中常用的功能及技巧总结
2017/01/07 PHP
Yii2框架数据验证操作实例详解
2018/05/02 PHP
laravel框架如何设置公共头和公共尾
2019/10/22 PHP
Javascript String.replace的妙用
2009/09/08 Javascript
半角全角相互转换的js函数
2009/10/16 Javascript
导入extjs、jquery 文件时$使用冲突问题解决方法
2014/01/14 Javascript
node.js中的http.get方法使用说明
2014/12/14 Javascript
JavaScript模拟数组合并concat
2016/03/06 Javascript
vue分类筛选filter方法简单实例
2017/03/30 Javascript
鼠标拖动改变DIV等网页元素的大小的实现方法
2017/07/06 Javascript
详解关于react-redux中的connect用法介绍及原理解析
2017/09/11 Javascript
JavaScript强制类型转换和隐式类型转换操作示例
2019/05/01 Javascript
Vue中对iframe实现keep alive无刷新的方法
2019/07/23 Javascript
微信小程序自定义navigationBar顶部导航栏适配所有机型(附完整案例)
2020/04/26 Javascript
python 字符串转列表 list 出现\ufeff的解决方法
2017/06/22 Python
python处理Excel xlrd的简单使用
2017/09/12 Python
Python 实现数据结构中的的栈队列
2019/05/16 Python
Django文件存储 默认存储系统解析
2019/08/02 Python
Python3 文章标题关键字提取的例子
2019/08/26 Python
使用OpenCV实现仿射变换—缩放功能
2019/08/29 Python
基于python调用psutil模块过程解析
2019/12/20 Python
Python日志logging模块功能与用法详解
2020/04/09 Python
Python3爬虫关于识别点触点选验证码的实例讲解
2020/07/30 Python
Python logging日志库空间不足问题解决
2020/09/14 Python
越南电子产品购物网站:FPT Shop
2017/12/02 全球购物
美国在线乐器和设备商店:Musician’s Friend
2018/07/06 全球购物
索引覆盖(Index Covering)查询含义
2012/02/18 面试题
如何写一封打动人心的求职信
2014/02/17 职场文书
高中军训感言400字
2014/02/24 职场文书
中学生爱国演讲稿
2014/09/05 职场文书
学校党的群众路线教育实践活动整改措施
2014/10/25 职场文书
2014年班组建设工作总结
2014/12/01 职场文书
Python中OpenCV实现简单车牌字符切割
2021/06/11 Python
zabbix自定义监控nginx状态实现过程
2021/11/01 Servers