tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中encode()方法的使用简介
May 18 Python
scrapy spider的几种爬取方式实例代码
Jan 25 Python
python针对不定分隔符切割提取字符串的方法
Oct 26 Python
django框架防止XSS注入的方法分析
Jun 21 Python
简单了解python的一些位运算技巧
Jul 13 Python
pygame实现俄罗斯方块游戏(基础篇2)
Oct 29 Python
Pytorch之parameters的使用
Dec 31 Python
win10安装tesserocr配置 Python使用tesserocr识别字母数字验证码
Jan 16 Python
解决python 执行sql语句时所传参数含有单引号的问题
Jun 06 Python
Python绘图之柱形图绘制详解
Jul 28 Python
python 使用Tensorflow训练BP神经网络实现鸢尾花分类
May 12 Python
Python实现双向链表
May 25 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
用Socket发送电子邮件
2006/10/09 PHP
同一空间绑定多个域名而实现访问不同页面的PHP代码
2006/12/06 PHP
彻底搞懂PHP 变量结构体
2017/10/11 PHP
PHP生成(支持多模板)二维码海报代码
2018/04/30 PHP
PHP聊天室简单实现方法详解
2018/12/08 PHP
javascript中callee与caller的用法和应用场景
2010/12/08 Javascript
奉献给JavaScript初学者的编写开发的七个细节
2011/01/11 Javascript
JS 无限级 Select效果实现代码(json格式)
2011/08/30 Javascript
JavaScript中两个感叹号的作用说明
2011/12/28 Javascript
javascript 上下banner替换具体实现
2013/11/14 Javascript
JS匀速运动演示示例代码
2013/11/26 Javascript
javascript数组随机排序实例分析
2015/07/22 Javascript
JavaScript中的数据类型转换方法小结
2015/10/26 Javascript
jquery简单插件制作(fn.extend)完整实例
2016/05/24 Javascript
JS控制TreeView的结点选择
2016/11/11 Javascript
jQuery使用正则表达式替换dom元素标签用法示例
2017/01/16 Javascript
Angular.js中ng-if、ng-show和ng-hide的区别介绍
2017/01/20 Javascript
微信小程序使用navigateTo数据传递的实例
2017/09/26 Javascript
this.$toast() 了解一下?
2019/04/18 Javascript
浅析vue-router中params和query的区别
2019/12/24 Javascript
Vue是怎么渲染template内的标签内容的
2020/06/05 Javascript
[56:00]2018DOTA2亚洲邀请赛 4.6 淘汰赛 VP vs TNC 第二场
2018/04/10 DOTA
实例讲解Python的函数闭包使用中应注意的问题
2016/06/20 Python
Python 在字符串中加入变量的实例讲解
2018/05/02 Python
python 定义n个变量方法 (变量声明自动化)
2018/11/10 Python
Python supervisor强大的进程管理工具的使用
2019/04/24 Python
keras的backend 设置 tensorflow,theano操作
2020/06/30 Python
教师专业自荐书范文
2014/02/10 职场文书
班级学雷锋活动总结
2014/06/26 职场文书
乡镇党的群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
中共广东省委常委会党的群众路线教育实践活动整改方案
2014/09/23 职场文书
八一建军节主持词
2015/07/01 职场文书
2015小学音乐教师个人工作总结
2015/07/21 职场文书
只需要这一行代码就能让python计算速度提高十倍
2021/05/24 Python
React列表栏及购物车组件使用详解
2021/06/28 Javascript
使用Apache Camel表达REST服务的方法
2022/06/10 Servers