tensorflow saver 保存和恢复指定 tensor的实例讲解


Posted in Python onJuly 26, 2018

在实践中经常会遇到这样的情况:

1、用简单的模型预训练参数

2、把预训练的参数导入复杂的模型后训练复杂的模型

这时就产生一个问题:

如何加载预训练的参数。

下面就是我的总结。

为了方便说明,做一个假设:简单的模型只有一个卷基层,复杂模型有两个。

卷积层的实现代码如下:

import tensorflow as tf
# PS:本篇的重担是saver,不过为了方便阅读还是说明下参数
# 参数
# name:创建卷基层的代码这么多,必须要函数化,而为了防止变量冲突就需要用tf.name_scope
# input_data:输入数据
# width, high:卷积小窗口的宽、高
# deep_before, deep_after:卷积前后的神经元数量
# stride:卷积小窗口的移动步长
def make_conv(name, input_data, width, high, deep_before,deep_after, stride, padding_type='SAME'):
 global parameters
 with tf.name_scope(name) asscope:
  weights =tf.Variable(tf.truncated_normal([width, high, deep_before, deep_after],
   dtype=tf.float32,stddev=0.01), trainable=True, name='weights')
  biases =tf.Variable(tf.constant(0.1, shape=[deep_after]), trainable=True, name='biases')
  conv =tf.nn.conv2d(input_data, weights, [1, stride, stride, 1], padding=padding_type)
  bias = tf.add(conv,biases)
  bias = batch_norm(bias,deep_after, 1) # batch_norm是自己写的batchnorm函数
  conv =tf.maximum(0.1*bias, bias)
  return conv

简单的预训练模型就下面一句话

conv1 =make_conv('simple-conv1', images, 3, 3, 3, 32, 1)

复杂的模型是两个卷基层,如下:

conv1 = make_conv('complex-conv1',images, 3, 3, 3, 32, 1)
pool1= make_max_pool('layer1-pool1', conv1, 2, 2)
conv2= make_conv('complex-conv2', pool1, 3, 3, 32, 64, 1)

这时简简单单的在预训练模型中:

saver = tf.train.Saver()
with tf.Session() as sess:
saver.save(sess,'model.ckpt')

就不行了,因为:

1,如果你在预训练模型中使用下面的话打印所有tensor

all_v =tf.global_variables()
for i in all_v: print i

会发现tensor的名字不是weights和biases,而是'simple-conv1/weights和'simple-conv1/biases,如下:

<tf.Variable'simple-conv1/weights:0' shape=(3, 3, 3, 32) dtype=float32_ref>

<tf.Variable'simple-conv1/biases:0' shape=(32,) dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_1:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_2:0' shape=(32,)dtype=float32_ref>

<tf.Variable 'simple-conv1/Variable_3:0' shape=(32,)dtype=float32_ref>

同理,在复杂模型中就是complex-conv1/weights和complex-conv1/biases,这是对不上号的。

2,预训练模型中只有1个卷积层,而复杂模型中有两个,而tensorflow默认会从模型文件('model.ckpt')中找所有的“可训练的”tensor,找不到会报错。

解决方法:

1,在预训练模型中定义全局变量

parm_dict={}

并在“return conv”上面添加下面两行

parm_dict['complex-conv1/weights']= weights
parm_dict['complex-conv1/']= biases

然后在定义saver时使用下面这句话:

saver= tf.train.Saver(parm_dict)

这样保存后的模型文件就对应到复杂模型上了。

2,在复杂模型中定义全局变量

parameters= []

并在“return conv”上面添加下面行

parameters+= [weights, biases]

然后判断如果是第二个卷积层就不更新parameters。

接着在定义saver时使用下面这句话:

saver= tf.train.Saver(parameters)

这样就可以告诉saver,只需要从模型文件中找weights和biases,而那些什么complex-conv1/Variable~ complex-conv1/Variable_3统统滚一边去(上面红色部分)。

最后使用下面的代码加载就可以了

with tf.Session() as sess:
 ckpt= tf.train.get_checkpoint_state('.')
 if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess,ckpt.model_checkpoint_path)
 else:
  print ' no saver.'
  exit()

以上这篇tensorflow saver 保存和恢复指定 tensor的实例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python抓取京东图书评论数据
Aug 31 Python
在Python中操作时间之strptime()方法的使用
Dec 30 Python
Python实现求最大公约数及判断素数的方法
May 26 Python
浅析Python中元祖、列表和字典的区别
Aug 17 Python
python3中str(字符串)的使用教程
Mar 23 Python
python爬虫获取淘宝天猫商品详细参数
Jun 23 Python
Python3.遍历某文件夹提取特定文件名的实例
Apr 26 Python
python数字图像处理之高级形态学处理
Apr 27 Python
Pytorch反向求导更新网络参数的方法
Aug 17 Python
复化梯形求积分实例——用Python进行数值计算
Nov 20 Python
Python2和Python3中@abstractmethod使用方法
Feb 04 Python
使用Python文件读写,自定义分隔符(custom delimiter)
Jul 05 Python
python opencv旋转图像(保持图像不被裁减)
Jul 26 #Python
详解Django中间件的5种自定义方法
Jul 26 #Python
python opencv实现切变换 不裁减图片
Jul 26 #Python
Flask之flask-script模块使用
Jul 26 #Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 #Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 #Python
Python迭代器与生成器基本用法分析
Jul 26 #Python
You might like
eaglephp使用微信api接口开发微信框架
2014/01/09 PHP
PHP生成短网址的3种方法代码实例
2014/07/08 PHP
php中mysql连接方式PDO使用详解
2015/02/25 PHP
php实现汉字验证码和算式验证码的方法
2015/03/07 PHP
PHP控制前台弹出对话框的实现方法
2016/08/21 PHP
Redis构建分布式锁
2017/03/28 PHP
HTML 自动伸缩的表格Table js实现
2009/04/01 Javascript
JS操作Cookie写入和读取实例代码
2013/10/20 Javascript
jQuery中多个元素的Hover事件解决方案
2014/06/12 Javascript
flash+jQuery实现可关闭及重复播放的压顶广告
2015/04/15 Javascript
JavaScript中解析JSON数据的三种方法
2015/07/03 Javascript
一篇文章掌握RequireJS常用知识
2016/01/26 Javascript
JavaScript中最容易混淆的作用域、提升、闭包知识详解(推荐)
2016/09/05 Javascript
Validform表单验证总结篇
2016/10/31 Javascript
原生Aajax 和jQuery Ajax 写法个人总结
2017/03/24 jQuery
浅探express路由和中间件的实现
2019/09/30 Javascript
javascript使用Blob对象实现的下载文件操作示例
2020/04/18 Javascript
vue实现简易的双向数据绑定
2020/12/29 Vue.js
详解tensorflow训练自己的数据集实现CNN图像分类
2018/02/07 Python
python装饰器深入学习
2018/04/06 Python
Python进阶之自定义对象实现切片功能
2019/01/07 Python
Falsk 与 Django 过滤器的使用与区别详解
2019/06/04 Python
python 缺失值处理的方法(Imputation)
2019/07/02 Python
python实现基于朴素贝叶斯的垃圾分类算法
2019/07/09 Python
CSS3等相关属性制作分页导航实现代码
2012/12/24 HTML / CSS
深入理解css属性的选择对动画性能的影响
2016/04/20 HTML / CSS
使用html2canvas实现将html内容写入到canvas中生成图片
2020/01/03 HTML / CSS
奥地利度假券的专家:we-are.travel
2019/04/10 全球购物
创业资金计划书
2014/02/06 职场文书
教师网络培训感言
2014/03/09 职场文书
2015年信息宣传工作总结
2015/05/26 职场文书
2015年行政管理人员工作总结
2015/10/15 职场文书
奖学金主要事迹范文
2015/11/04 职场文书
30岁前绝不能错过的10本书
2019/08/08 职场文书
快消品行业营销模式与盈利模式分享
2019/09/27 职场文书
python flask开发的简单基金查询工具
2021/06/02 Python