tensorflow实现从.ckpt文件中读取任意变量


Posted in Python onMay 26, 2020

思路有些混乱,希望大家能理解我的意思。

看了faster rcnn的tensorflow代码,关于fix_variables的作用我不是很明白,所以写了以下代码,读取了预训练模型vgg16得fc6和fc7的参数,以及faster rcnn中heat_to_tail中的fc6和fc7,将它们做了对比,发现结果不一样,说明vgg16的fc6和fc7只是初始化了faster rcnn中heat_to_tail中的fc6和fc7,之后后者被训练。

具体读取任意变量的代码如下:

import tensorflow as tf
import numpy as np
from tensorflow.python import pywrap_tensorflow
 
file_name = '/home/dl/projectBo/tf-faster-rcnn/data/imagenet_weights/vgg16.ckpt' #.ckpt的路径
name_variable_to_restore = 'vgg_16/fc7/weights' #要读取权重的变量名
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()
print('shape', var_to_shape_map[name_variable_to_restore]) #输出这个变量的尺寸
fc7_conv = tf.get_variable("fc7", var_to_shape_map[name_variable_to_restore], trainable=False) # 定义接收权重的变量名
restorer_fc = tf.train.Saver({name_variable_to_restore: fc7_conv }) #定义恢复变量的对象
sess = tf.Session()
sess.run(tf.variables_initializer([fc7_conv], name='init')) #必须初始化
restorer_fc.restore(sess, file_name) #恢复变量
print(sess.run(fc7_conv)) #输出结果

用以上的代码分别读取两个网络的fc6 和 fc7 ,对应参数尺寸和权值都不同,但参数量相同。

再看lib/nets/vgg16.py中的:

(注意注释)

def fix_variables(self, sess, pretrained_model):
 print('Fix VGG16 layers..')
 with tf.variable_scope('Fix_VGG16') as scope:
  with tf.device("/cpu:0"):
   # fix the vgg16 issue from conv weights to fc weights
   # fix RGB to BGR
   fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)      
   fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
   conv1_rgb = tf.get_variable("conv1_rgb", [3, 3, 3, 64], trainable=False)   #定义接收权重的变量,不可被训练
   restorer_fc = tf.train.Saver({self._scope + "/fc6/weights": fc6_conv, 
                  self._scope + "/fc7/weights": fc7_conv,
                  self._scope + "/conv1/conv1_1/weights": conv1_rgb}) #定义恢复变量的对象
   restorer_fc.restore(sess, pretrained_model) #恢复这些变量
 
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc6/weights:0'], tf.reshape(fc6_conv, 
             self._variables_to_fix[self._scope + '/fc6/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/fc7/weights:0'], tf.reshape(fc7_conv, 
             self._variables_to_fix[self._scope + '/fc7/weights:0'].get_shape())))
   sess.run(tf.assign(self._variables_to_fix[self._scope + '/conv1/conv1_1/weights:0'], 
             tf.reverse(conv1_rgb, [2])))         #将vgg16中的fc6、fc7中的权重reshape赋给faster-rcnn中的fc6、fc7

我的理解:faster rcnn的网络继承了分类网络的特征提取权重和分类器的权重,让网络从一个比较好的起点开始被训练,有利于训练结果的快速收敛。

补充知识:TensorFlow:加载部分ckpt文件变量&不同命名空间中加载模型

TensorFlow中,在加载和保存模型时,一般会直接使用tf.train.Saver.restore()和tf.train.Saver.save()

然而,当需要选择性加载模型参数时,则需要利用pywrap_tensorflow读取模型,分析模型内的变量关系。

例子:Faster-RCNN中,模型加载vgg16.ckpt,需要利用pywrap_tensorflow读取ckpt文件中的参数

from tensorflow.python import pywrap_tensorflow
 
model=VGG16()#此处构建vgg16模型
variables = tf.global_variables()#获取模型中所有变量
 
file_name='vgg16.ckpt'#vgg16网络模型
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
var_to_shape_map = reader.get_variable_to_shape_map()#获取ckpt模型中的变量名
print(var_to_shape_map)
 
sess=tf.Session()
 
my_scope='my/'#外加的空间名
variables_to_restore={}#构建字典:需要的变量和对应的模型变量的映射
for v in variables:
  if my_scope in v.name and v.name.split(':')[0].split(my_scope)[1] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name.split(':0')[0][len(my_scope):]]=v
  elif v.name.split(':')[0] in var_to_shape_map:
    print('Variables restored: %s' % v.name)
    variables_to_restore[v.name]=v
 
restorer=tf.train.Saver(variables_to_restore)#将需要加载的变量作为参数输入
restorer.restore(sess, file_name)

实际中,Faster RCNN中所构建的vgg16网络的fc6和fc7权重shape如下:

<tf.Variable 'my/vgg_16/fc6/weights:0' shape=(25088, 4096) dtype=float32_ref>,
<tf.Variable 'my/vgg_16/fc7/weights:0' shape=(4096, 4096) dtype=float32_ref>,

vgg16.ckpt的fc6,fc7权重shape如下:

'vgg_16/fc6/weights': [7, 7, 512, 4096],
'vgg_16/fc7/weights': [1, 1, 4096, 4096],

因此,有如下操作:

fc6_conv = tf.get_variable("fc6_conv", [7, 7, 512, 4096], trainable=False)
fc7_conv = tf.get_variable("fc7_conv", [1, 1, 4096, 4096], trainable=False)
        
restorer_fc = tf.train.Saver({"vgg_16/fc6/weights": fc6_conv,
               "vgg_16/fc7/weights": fc7_conv,
               })
restorer_fc.restore(sess, pretrained_model)
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc6/weights:0'], tf.reshape(fc6_conv,self._variables_to_fix['my/vgg_16/fc6/weights:0'].get_shape())))  
sess.run(tf.assign(self._variables_to_fix['my/vgg_16/fc7/weights:0'], tf.reshape(fc7_conv,self._variables_to_fix['my/vgg_16/fc7/weights:0'].get_shape())))

以上这篇tensorflow实现从.ckpt文件中读取任意变量就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中encode()方法的使用简介
May 18 Python
详解Python中的变量及其命名和打印
Mar 11 Python
python实现聊天小程序
Mar 13 Python
朴素贝叶斯分类算法原理与Python实现与使用方法案例
Jun 26 Python
python线程安全及多进程多线程实现方法详解
Sep 27 Python
python3中rank函数的用法
Nov 27 Python
Pytorch之保存读取模型实例
Dec 30 Python
TFRecord格式存储数据与队列读取实例
Jan 21 Python
PageFactory设计模式基于python实现
Apr 14 Python
更新升级python和pip版本后不生效的问题解决
Apr 17 Python
keras .h5转移动端的.tflite文件实现方式
May 25 Python
Python 类,对象,数据分类,函数参数传递详解
Sep 25 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
You might like
一个程序下载的管理程序(三)
2006/10/09 PHP
AJAX的跨域访问-两种有效的解决方法介绍
2013/06/22 PHP
PHP编程中的__clone()方法使用详解
2015/11/27 PHP
PHP基于ip2long实现IP转换整形
2020/12/11 PHP
网易JS面试题与Javascript词法作用域说明
2010/11/09 Javascript
jQuery学习笔记之DOM对象和jQuery对象
2010/12/22 Javascript
jquery事件机制扩展插件 jquery鼠标右键事件。
2011/12/26 Javascript
jquery实现输入框动态增减的实例代码
2013/07/14 Javascript
页面载入结束自动调用js函数示例
2013/09/23 Javascript
js判断当页面无法回退时关闭网页否则就history.go(-1)
2014/08/07 Javascript
JavaScript实现找质数代码分享
2015/03/24 Javascript
Javascript小技能总结(推荐)
2016/06/02 Javascript
EasyUI为Numberbox添加blur事件的方法
2017/03/05 Javascript
Vue.js实战之利用vue-router实现跳转页面
2017/04/01 Javascript
解决bootstrap下拉菜单点击立即隐藏bug的方法
2017/06/13 Javascript
详解webpack babel的配置
2018/01/09 Javascript
浅谈Vue下使用百度地图的简易方法
2018/03/23 Javascript
解决bootstrap-select 动态加载数据不显示的问题
2018/08/10 Javascript
详解Vue-Router源码分析路由实现原理
2019/05/15 Javascript
Javascript Dom元素获取和添加详解
2019/09/24 Javascript
python完成FizzBuzzWhizz问题(拉勾网面试题)示例
2014/05/05 Python
python实现带声音的摩斯码翻译实现方法
2015/05/20 Python
PIL对上传到Django的图片进行处理并保存的实例
2019/08/07 Python
Python基于gevent实现高并发代码实例
2020/05/15 Python
Spring http服务远程调用实现过程解析
2020/06/11 Python
Python字节单位转换(将字节转换为K M G T)
2021/03/02 Python
使用简单的CSS3属性实现炫酷读者墙效果
2014/01/08 HTML / CSS
英国Radley包德国官网:Radley London德国
2019/11/18 全球购物
Internet体系结构
2014/12/21 面试题
九州传奇上机题
2014/07/10 面试题
主题班会演讲稿
2014/05/22 职场文书
小学标准化建设汇报材料
2014/08/16 职场文书
教师工作证明范本
2015/06/12 职场文书
Oracle更换为MySQL遇到的问题及解决
2021/05/21 Oracle
使用jpa之动态插入与修改(重写save)
2021/11/23 Java/Android
Nginx配置使用详解
2022/07/07 Servers