tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用实例解释Python中的继承和多态的概念
Apr 27 Python
python实现对一个完整url进行分割的方法
Apr 29 Python
Python打造出适合自己的定制化Eclipse IDE
Mar 02 Python
Django学习笔记之Class-Based-View
Feb 15 Python
python学生信息管理系统
Mar 13 Python
python+pandas+时间、日期以及时间序列处理方法
Jul 10 Python
python爬虫 execjs安装配置及使用
Jul 30 Python
用Python调用win命令行提高工作效率的实例
Aug 14 Python
解决Python对齐文本字符串问题
Aug 28 Python
Python 获取numpy.array索引值的实例
Dec 06 Python
python matlab库简单用法讲解
Dec 31 Python
matplotlib之多边形选区(PolygonSelector)的使用
Feb 24 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
PHP采集相关教程之一 CURL函数库
2010/02/15 PHP
PHP页面中文乱码分析
2013/10/29 PHP
解读PHP的Yii框架中请求与响应的处理流程
2016/03/17 PHP
Zend Framework数据库操作技巧总结
2017/02/18 PHP
PHP使用GD库制作验证码的方法(点击验证码或看不清会刷新验证码)
2017/08/15 PHP
php+ajax实现仿百度查询下拉内容功能示例
2017/10/20 PHP
判断某个字符在一个字符串中是否存在的js代码
2014/02/28 Javascript
使用GruntJS构建Web程序之构建篇
2014/06/04 Javascript
jQuery中;function($,undefined) 前面的分号的用处
2014/12/17 Javascript
jQuery中[attribute!=value]选择器用法实例
2014/12/31 Javascript
jQuery+html5实现div弹出层并遮罩背景
2015/04/15 Javascript
基于jQuery实现Accordion手风琴自定义插件
2020/10/13 Javascript
js实现省份下拉菜单效果
2017/02/15 Javascript
ES6中的箭头函数实例详解
2017/04/06 Javascript
微信小程序 实现点击添加移除class
2017/06/12 Javascript
jQuery实现checkbox即点即改批量删除及中间遇到的坑
2017/11/11 jQuery
Vue2.x Todo之自定义指令实现自动聚焦的方法
2019/01/08 Javascript
微信小程序通过一个json实现分享朋友圈图片
2019/09/03 Javascript
详解Vue template 如何支持多个根结点
2020/02/10 Javascript
如何使用gpu.js改善JavaScript的性能
2020/12/01 Javascript
Python常见数据结构详解
2014/07/24 Python
python实现感知器
2017/12/19 Python
python搭建服务器实现两个Android客户端间收发消息
2018/04/12 Python
python re模块的高级用法详解
2018/06/06 Python
Python/ArcPy遍历指定目录中的MDB文件方法
2018/10/27 Python
对python dataframe逻辑取值的方法详解
2019/01/30 Python
基于HTML5+CSS3实现简单的时钟效果
2017/09/11 HTML / CSS
Python面试题:Python里面如何生成随机数
2015/03/12 面试题
现金会计岗位职责
2013/12/05 职场文书
精彩的英文自荐信
2014/01/30 职场文书
网站创业计划书
2014/04/30 职场文书
实习科室评语
2015/01/04 职场文书
花木兰观后感
2015/06/10 职场文书
信息技术远程培训心得体会
2016/01/09 职场文书
python中的plt.cm.Paired用法说明
2021/05/31 Python
Python 视频画质增强
2022/04/28 Python