python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Django框架中伪造捕捉到的URLconf值的方法
Jul 18 Python
Python使用SQLite和Excel操作进行数据分析
Jan 20 Python
30秒轻松实现TensorFlow物体检测
Mar 14 Python
Python装饰器原理与用法分析
Apr 30 Python
Python操作json的方法实例分析
Dec 06 Python
python腾讯语音合成实现过程解析
Aug 01 Python
Python Django 封装分页成通用的模块详解
Aug 21 Python
python中必要的名词解释
Nov 20 Python
详解python破解zip文件密码的方法
Jan 13 Python
python实现逆滤波与维纳滤波示例
Feb 26 Python
Python3.7 读取音频根据文件名生成脚本的代码
Apr 07 Python
在pycharm中关掉ipython console/PyDev操作
Jun 09 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
天使彦史上最神还原,性别曝光的那一刻,百万网友恋爱了
2020/03/02 国漫
PHP用SAX解析XML的实现代码与问题分析
2011/08/22 PHP
PHP ? EasyUI DataGrid 资料取的方式介绍
2012/11/07 PHP
解析在PHP中使用全局变量的几种方法
2013/06/24 PHP
jQuery Mobile + PHP实现文件上传
2014/12/12 PHP
php给图片加文字水印
2015/07/31 PHP
phpinfo() 中 Local Value(局部变量)Master Value(主变量) 的区别
2016/02/03 PHP
thinkphp框架下404页面设置 仅三步
2016/05/14 PHP
php二维码生成以及下载实现
2017/09/28 PHP
JavaScript 捕获窗口关闭事件
2009/07/26 Javascript
解决jquery中美元符号命名冲突问题
2014/01/08 Javascript
JavaScript 事件绑定及深入
2015/04/13 Javascript
轻松学习Javascript闭包
2017/03/01 Javascript
JS实现预加载视频音频/视频获取截图(返回canvas截图)
2017/10/09 Javascript
JQuery表单元素取值赋值方法总结
2020/05/12 jQuery
微信小程序自定义联系人弹窗
2020/05/26 Javascript
Vue.extend 登录注册模态框的实现
2020/12/29 Vue.js
vue浏览器返回监听的具体步骤
2021/02/03 Vue.js
python实现得到一个给定类的虚函数
2014/09/28 Python
python中ConfigParse模块的用法
2014/09/29 Python
Python的Scrapy爬虫框架简单学习笔记
2016/01/20 Python
Windows安装Python、pip、easy_install的方法
2017/03/05 Python
python的Crypto模块实现AES加密实例代码
2018/01/22 Python
python实现图书馆研习室自动预约功能
2018/04/27 Python
python操作redis方法总结
2018/06/06 Python
详解如何设置Python环境变量?
2019/05/13 Python
python 利用pywifi模块实现连接网络破解wifi密码实时监控网络
2019/09/16 Python
CSS3的resize属性使用初探
2015/09/27 HTML / CSS
让IE支持HTML5的方法
2012/12/11 HTML / CSS
Speedo速比涛中国官方网站:全球领先泳装运动品牌
2018/04/24 全球购物
俄罗斯电动工具和设备购物网站:Vseinstrumenti.ru
2020/11/12 全球购物
vue+django实现下载文件的示例
2021/03/24 Vue.js
架构师岗位职责
2013/11/18 职场文书
公司领导班子对照检查存在问题整改措施
2014/10/02 职场文书
2014年部门工作总结
2014/11/12 职场文书
Go中使用gjson来操作JSON数据的实现
2022/08/14 Golang