python使用tensorflow保存、加载和使用模型的方法


Posted in Python onJanuary 31, 2018

使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:

http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

我对这篇文章进行了整理和汇总。

首先是模型的保存。直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut1_save.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 11:04:25 
############################ 
 
import tensorflow as tf 
 
# prepare to feed input, i.e. feed_dict and placeholders 
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration 
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2') 
b1 = tf.Variable(2.0, name = 'bias1') 
feed_dict = {w1:[10,3], w2:[5,5]} 
 
# define a test operation that will be restored 
w3 = tf.add(w1, w2) # without name, w3 will not be stored 
w4 = tf.multiply(w3, b1, name = "op_to_restore") 
 
#saver = tf.train.Saver() 
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
print sess.run(w4, feed_dict) 
#saver.save(sess, 'my_test_model', global_step = 100) 
saver.save(sess, 'my_test_model') 
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)

需要说明的有以下几点:

1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。

2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。

3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut2_import.py 
#Author: Wang  
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:16:38 
############################  
import tensorflow as tf 
sess = tf.Session() 
new_saver = tf.train.import_meta_graph('my_test_model.meta') 
new_saver.restore(sess, tf.train.latest_checkpoint('./')) 
print sess.run('w1:0')

使用加载的模型,输入新数据,计算输出,还是直接上代码:

#!/usr/bin/env python 
#-*- coding:utf-8 -*- 
############################ 
#File Name: tut3_reuse.py 
#Author: Wang 
#Mail: wang19920419@hotmail.com 
#Created Time:2017-08-30 14:33:35 
############################ 
 
import tensorflow as tf 
 
sess = tf.Session() 
 
# First, load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model.meta') 
saver.restore(sess, tf.train.latest_checkpoint('./')) 
 
# Second, access and create placeholders variables and create feed_dict to feed new data 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name('w1:0') 
w2 = graph.get_tensor_by_name('w2:0') 
feed_dict = {w1:[-1,1], w2:[4,6]} 
 
# Access the op that want to run 
op_to_restore = graph.get_tensor_by_name('op_to_restore:0') 
 
print sess.run(op_to_restore, feed_dict)   # ouotput: [6. 14.]

在已经加载的网络后继续加入新的网络层:

import tensorflow as tf 
sess=tf.Session()   
#First let's load meta graph and restore weights 
saver = tf.train.import_meta_graph('my_test_model-1000.meta') 
saver.restore(sess,tf.train.latest_checkpoint('./')) 

# Now, let's access and create placeholders variables and 
# create feed-dict to feed new data 
 
graph = tf.get_default_graph() 
w1 = graph.get_tensor_by_name("w1:0") 
w2 = graph.get_tensor_by_name("w2:0") 
feed_dict ={w1:13.0,w2:17.0} 
 
#Now, access the op that you want to run.  
op_to_restore = graph.get_tensor_by_name("op_to_restore:0") 
 
#Add more to the current graph 
add_on_op = tf.multiply(op_to_restore,2) 
 
print sess.run(add_on_op,feed_dict) 
#This will print 120.

对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):

...... 
...... 
saver = tf.train.import_meta_graph('vgg.meta') 
# Access the graph 
graph = tf.get_default_graph() 
## Prepare the feed_dict for feeding data for fine-tuning  
 
#Access the appropriate output for fine-tuning 
fc7= graph.get_tensor_by_name('fc7:0') 
 
#use this if you only want to change gradients of the last layer 
fc7 = tf.stop_gradient(fc7) # It's an identity function 
fc7_shape= fc7.get_shape().as_list() 
 
new_outputs=2 
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05)) 
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs])) 
output = tf.matmul(fc7, weights) + biases 
pred = tf.nn.softmax(output) 
 
# Now, you run this with fine-tuning data in sess.run()

有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python以环状形式组合排列图片并输出的方法
Mar 17 Python
Python StringIO模块实现在内存缓冲区中读写数据
Apr 08 Python
用python写个自动SSH登录远程服务器的小工具(实例)
Jun 17 Python
python之Character string(实例讲解)
Sep 25 Python
python 平衡二叉树实现代码示例
Jul 07 Python
Python实战购物车项目的实现参考
Feb 20 Python
Python多叉树的构造及取出节点数据(treelib)的方法
Aug 09 Python
python3 requests库实现多图片爬取教程
Dec 18 Python
Python中url标签使用知识点总结
Jan 16 Python
Python类如何定义私有变量
Feb 03 Python
python发qq消息轰炸虐狗好友思路详解(完整代码)
Feb 15 Python
Django中session进行权限管理的使用
Jul 09 Python
python通过elixir包操作mysql数据库实例代码
Jan 31 #Python
Django视图和URL配置详解
Jan 31 #Python
Python编程求质数实例代码
Jan 31 #Python
Python及Django框架生成二维码的方法分析
Jan 31 #Python
Python进阶之尾递归的用法实例
Jan 31 #Python
简单的python协同过滤程序实例代码
Jan 31 #Python
Python进阶之递归函数的用法及其示例
Jan 31 #Python
You might like
php中get_headers函数的作用及用法的详细介绍
2013/04/27 PHP
PHP 文件编程综合案例-文件上传的实现
2013/07/03 PHP
PHP实现添加购物车功能
2017/03/06 PHP
php 生成加密公钥加密私钥实例详解
2017/06/16 PHP
PHP编程获取图片的主色调的方法【基于Imagick扩展】
2017/08/02 PHP
php 可变函数使用小结
2018/06/12 PHP
实例讲解PHP验证邮箱是否合格
2019/01/28 PHP
js 验证密码强弱的小例子
2013/03/21 Javascript
JQuery为textarea添加maxlength属性并且兼容IE
2013/04/25 Javascript
js中运算符&& 和 || 的使用记录
2014/08/21 Javascript
JS模式之简单的订阅者和发布者模式完整实例
2015/06/30 Javascript
JavaScript对数组进行随机重排的方法
2015/07/22 Javascript
jQuery中$.grep() 过滤函数 数组过滤
2016/11/22 Javascript
JS动态添加的div点击跳转到另一页面实现代码
2017/09/30 Javascript
jQuery实现的简单日历组件定义与用法示例
2018/12/24 jQuery
35个最好用的Vue开源库(史上最全)
2019/01/03 Javascript
JavaScript动态创建二维数组的方法示例
2019/02/01 Javascript
vue分页插件的使用方法
2019/12/25 Javascript
[02:56]DOTA2上海特锦赛小组赛解说FreeAgain采访花絮
2016/02/27 DOTA
从零学Python之入门(五)缩进和选择
2014/05/27 Python
Python中字符串的格式化方法小结
2016/05/03 Python
Python实现的rsa加密算法详解
2018/01/24 Python
Python基于多线程实现抓取数据存入数据库的方法
2018/06/22 Python
在python中实现调用可执行文件.exe的3种方法
2019/07/07 Python
python base64库给用户名或密码加密的流程
2020/01/02 Python
python 的topk算法实例
2020/04/02 Python
Python如何给你的程序做性能测试
2020/07/29 Python
Python生成pdf目录书签的实例方法
2020/10/29 Python
有750多个顶级品牌的瑞士时尚在线:ABOUT YOU
2017/01/04 全球购物
ZINVO手表官网:男士和女士手表
2019/03/10 全球购物
电力安全事故反思
2014/04/27 职场文书
警察先进个人事迹材料
2014/05/16 职场文书
区域销售主管岗位职责
2014/06/15 职场文书
vue实现滑动解锁功能
2022/03/03 Vue.js
openstack云计算keystone组件工作介绍
2022/04/20 Servers
js判断两个数组相等的5种方法
2022/05/06 Javascript