python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量提交沙箱问题实例
Oct 08 Python
深入剖析Python的爬虫框架Scrapy的结构与运作流程
Jan 20 Python
Python标准库06之子进程 (subprocess包) 详解
Dec 07 Python
Python win32com 操作Exce的l简单方法(必看)
May 25 Python
sublime text 3配置使用python操作方法
Jun 11 Python
Python自动化开发学习之三级菜单制作
Jul 14 Python
matplotlib简介,安装和简单实例代码
Dec 26 Python
Python爬虫之正则表达式基本用法实例分析
Aug 08 Python
通过python将大量文件按修改时间分类的方法
Oct 17 Python
Python多继承以及MRO顺序的使用
Nov 11 Python
MoviePy简介及Python视频剪辑自动化
Dec 18 Python
python如何利用cv2.rectangle()绘制矩形框
Dec 24 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
个人写的PHP验证码生成类分享
2014/08/21 PHP
ECSHOP在PHP5.5及高版本上报错的解决方法
2015/08/31 PHP
php 从一个数组中随机的取出若干个不同的数实例
2016/12/31 PHP
轻轻松松学习JavaScript
2007/02/25 Javascript
JS提交并解析后台返回的XML的代码
2008/11/03 Javascript
Jquery AJAX 用于计算点击率(统计)
2010/06/30 Javascript
Jquery 选中表格一列并对表格排序实现原理
2012/12/15 Javascript
JS实现div居中示例
2014/04/17 Javascript
JavaScript必知必会(三) String .的方法来自何方
2016/06/08 Javascript
jquery dataview数据视图插件使用方法
2016/12/23 Javascript
使用vue构建一个上传图片表单
2017/07/04 Javascript
引入JavaScript时alert弹出框显示中文乱码问题
2017/09/16 Javascript
Vue.use源码学习小结
2018/06/20 Javascript
Javascript格式化并高亮xml字符串的方法及注意事项
2018/08/13 Javascript
JavaScript console的使用方法实例分析
2020/04/28 Javascript
vue 保留两位小数 不能直接用toFixed(2) 的解决
2020/08/07 Javascript
antd-DatePicker组件获取时间值,及相关设置方式
2020/10/27 Javascript
Python实现的RSS阅读器实例
2015/07/25 Python
Windows系统下PhantomJS的安装和基本用法
2018/10/21 Python
python实现复制文件到指定目录
2019/10/16 Python
浅谈Python协程
2020/06/17 Python
python爬虫请求头设置代码
2020/07/28 Python
关于Python3爬虫利器Appium的安装步骤
2020/07/29 Python
Python实现数字的格式化输出
2020/08/01 Python
python中的split、rsplit、splitlines用法说明
2020/10/23 Python
python分布式爬虫中消息队列知识点详解
2020/11/26 Python
pytorch 实现L2和L1正则化regularization的操作
2021/03/03 Python
美国电子产品折扣网站:Daily Steals
2017/05/20 全球购物
AJAX检测用户名是否存在的方法
2021/03/24 Javascript
新手上路标语
2014/06/20 职场文书
机关党建工作汇报材料
2014/08/20 职场文书
个人作风建设总结
2014/10/23 职场文书
个人工作总结范文2014
2014/11/07 职场文书
2016年校园植树节广播稿
2015/12/17 职场文书
2016年师德先进个人事迹材料
2016/02/29 职场文书
家长必看:义务教育,不得以面试 评测等名义选拔学生
2019/07/09 职场文书