tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
十条建议帮你提高Python编程效率
Feb 16 Python
python3学习笔记之多进程分布式小例子
Feb 13 Python
Django框架使用富文本编辑器Uedit的方法分析
Jul 31 Python
在python 中实现运行多条shell命令
Jan 07 Python
Python基础之条件控制操作示例【if语句】
Mar 23 Python
Python函数和模块的使用总结
May 20 Python
Python学习笔记之Break和Continue用法分析
Aug 14 Python
python requests.get带header
May 05 Python
python Matplotlib基础--如何添加文本和标注
Jan 26 Python
pandas提升计算效率的一些方法汇总
May 30 Python
python神经网络 tf.name_scope 和 tf.variable_scope 的区别
May 04 Python
Django框架中视图的用法
Jun 10 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
php中在PDO中使用事务(Transaction)
2011/05/14 PHP
浅析php适配器模式(Adapter)
2014/11/25 PHP
php专用数组排序类ArraySortUtil用法实例
2015/04/03 PHP
thinkPHP5.0框架自动加载机制分析
2017/03/18 PHP
jquery 表单下所有元素的隐藏
2009/07/25 Javascript
JavaScript 学习笔记(六)
2009/12/31 Javascript
js性能优化 如何更快速加载你的JavaScript页面
2012/03/17 Javascript
jQuery和AngularJS的区别浅析
2015/01/29 Javascript
DOM事件阶段以及事件捕获与事件冒泡先后执行顺序(图文详解)
2015/08/18 Javascript
JS实现超简单的鼠标拖动效果
2015/11/02 Javascript
Bootstrap每天必学之js插件
2015/11/30 Javascript
jQuery模仿单选按钮选中效果
2016/06/24 Javascript
利用jQuery对无序列表排序的简单方法
2016/10/16 Javascript
JS实现页面打印功能
2017/03/16 Javascript
解决vue打包之后静态资源图片失效的问题
2018/02/21 Javascript
vue2 前端搜索实现示例
2018/02/26 Javascript
Vue.js进阶知识点总结
2018/04/01 Javascript
JS与SQL方式随机生成高强度密码示例
2018/12/29 Javascript
UEditor 自定义图片视频尺寸校验功能的实现代码
2020/10/20 Javascript
[02:43]DOTA2英雄基础教程 半人马战行者
2014/01/13 DOTA
使用Python的Twisted框架构建非阻塞下载程序的实例教程
2016/05/25 Python
python中的set实现不重复的排序原理
2018/01/24 Python
PYTHON基础-时间日期处理小结
2018/05/05 Python
使用Python处理Excel表格的简单方法
2018/06/07 Python
python2与python3的print及字符串格式化小结
2018/11/30 Python
Python发展史及网络爬虫
2019/06/19 Python
python模拟菜刀反弹shell绕过限制【推荐】
2019/06/25 Python
Python面向对象原理与基础语法详解
2020/01/02 Python
VELTRA台湾:世界自由行专家
2017/08/15 全球购物
暑假实习求职信范文
2013/09/22 职场文书
无工作经验者个人求职信范文
2013/12/22 职场文书
秋季红领巾广播稿
2014/01/27 职场文书
就业协议书样本
2014/08/20 职场文书
项目经理岗位职责范本
2015/04/01 职场文书
Python初识逻辑与if语句及用法大全
2021/08/07 Python
SQL试题 使用窗口函数选出连续3天登录的用户
2022/04/24 Oracle