tensorflow从ckpt和从.pb文件读取变量的值方式


Posted in Python onMay 26, 2020

最近在学习tensorflow自带的量化工具的相关知识,其中遇到的一个问题是从tensorflow保存好的ckpt文件或者是保存后的.pb文件(这里的pb是把权重和模型保存在一起的pb文件)读取权重,查看量化后的权重是否变成整形。

因此将自己解决这个问题记录下来,为了下一次遇到时,可以有所参考,也希望给有需要的同学一个可能的参考。

(1) 从保存的ckpt读取变量的值(以读取保存的第一个权重为例)

from tensorflow.python import pywrap_tensorflow 
import tensorflow as tf
with tf.Graph().as_default(): 
 with tf.Session() as sess: 
 ckpt = tf.train.get_checkpoint_state('./model_ckpt') #保存ckpt文件的文件夹
 if ckpt and ckpt.model_checkpoint_path: 
 reader = pywrap_tensorflow.NewCheckpointReader('./model_ckpt/model.ckpt-999') #自己保存的ckpt文件名
 all_variables = reader.get_variable_to_shape_map() 
 w1 = reader.get_tensor("Variable_1") 
 print(w1.shape) 
 print(w1) 
 else: print('No checkpoint file found')

(2) 从保存的.pb文件读取变量的值(以读取保存的第一个权重为例)

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
sess = tf.Session()
with gfile.FastGFile('Yourpb.pb', 'rb') as f: #自己保存的pb文件
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def, name='') 
 print(sess.run('Variable_1:0'))

补充知识:如何从已存在的检查点文件(cpkt文件)种解析出里面变量——无需重新创建原始计算图

import tensorflow as tf
import os

CheckpointReader

tf.train.NewCheckpointReader是一个创建检查点读取器(CheckpointReader)对象的完美手段。 CheckpointReader中有几个非常有用的方法:

get_variable_to_shape_map() - 提供具有变量名称和形状的字典

debug_string() - 提供由检查点文件中所有变量组成的字符串

has_tensor(var_name) - 允许检查变量是否存在于检查点中

get_tensor(var_name) - 返回变量名称的张量

为了便于说明,我将定义一个函数来检查路径的有效性,并为您加载检查点读取器。

In [3]:

def load_reader(path):
 assert os.path.exists(path), "Provided incorrect path to the file. {} doesn't exist".format(path)
 return tf.train.NewCheckpointReader(path)

In [34]:

your_path = 'logs/squeezeDet1024x1024/train/model.ckpt-0'
reader = load_reader(your_path)

reader.debug_string()

用于返回包含以下内容的一个字符串:

variable name(变量名)

data type(数据类型)

tensor shape(张量类型)

它返回字符串的各元素间均用空格符' '分隔,你可以使用debug_string来创建一个变量名列表,如下所示:

In [53]:

all_var_descriptions = reader.debug_string().split()
var_names, var_shapes = all_var[::3], all_var[2::3]
print(var_names[:4])
print(var_shapes[:4])

输出:

['iou', 'fire9/squeeze1x1/kernels', 'fire9/squeeze1x1/biases', 'fire9/expand3x3/kernels/Momentum']
['[10,36864]', '[1,1,512,64]', '[64]', '[3,3,64,256]']

但是,对于完成同样的任务,更好的方法是使用reader.get_variable_to_shape_map()

reader.get_variable_to_shape_map()

用于返回包含所有变量及其形状名称的字典,变量作为字典的Key,形状作为Value。

In [66]:

saved_shapes = reader.get_variable_to_shape_map()
print('fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels'])
fire9/squeeze1x1/kernels: [1, 1, 512, 64]
reader.has_tensor(var_name)

返回bool值

这是一种方便的方法,允许您检查ckeckpoint中是否存在相关的变量。

In [51]:

names_that_exit = {var_name: reader.has_tensor(var_name) for var_name in var_names[:10]}
for key in names_that_exit:
 print(key.decode()+':', names_that_exit[key])
fire8/squeeze1x1/kernels/Momentum: True
fire9/expand3x3/kernels: True
iou: True
fire9/expand3x3/biases: True
fire9/expand1x1/kernels: True
fire9/expand3x3/kernels/Momentum: True
fire9/expand1x1/biases/Momentum: True
fire9/squeeze1x1/biases: True
fire9/expand1x1/kernels/Momentum: True
fire9/squeeze1x1/kernels: True
reader.get_tensor(tensor_name)

返回包含检查点的张量值的NumPy数组

正常使用方法是先恢复一个张量,然后用恢复的张量初始化你自己的变量:

In [60]:

def recover_var(reader, var_name):
 recovered_var = 'var to be recovered'
 try:
  recovered_var = reader.get_tensor(var_name)
 except:
  assert reader.has_tensor(var_name),\
  "{} variable doesn't exist in the check point. Please check the variable name".format(var_name)
 return recovered_var

In [67]:

checkpoint_var = recover_var(reader, 'conv1/kernels')
print ("Recovered variable has the following shape: \n", checkpoint_var.shape)
new_var = tf.Variable(initial_value=checkpoint_var, name="new_conv1")
print ("New variable will be initialized with recovered values and the following shape: \n", new_var.get_shape())
Recovered variable has the following shape: 
(3, 3, 3, 64)
New variable will be initialized with recovered values and the following shape: 
(3, 3, 3, 64)

以上这篇tensorflow从ckpt和从.pb文件读取变量的值方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python程序中操作文件之flush()方法的使用教程
May 24 Python
Python程序中设置HTTP代理
Nov 06 Python
python爬虫获取淘宝天猫商品详细参数
Jun 23 Python
python中将一个全部为int的list 转化为str的list方法
Apr 09 Python
Python判断一个文件夹内哪些文件是图片的实例
Dec 07 Python
解决Python pandas plot输出图形中显示中文乱码问题
Dec 12 Python
python中时间模块的基本使用教程
May 14 Python
pyQT5 实现窗体之间传值的示例
Jun 20 Python
python调用c++返回带成员指针的类指针实例
Dec 12 Python
将python字符串转化成长表达式的函数eval实例
May 11 Python
Python中bisect的用法及示例详解
Jul 20 Python
BeautifulSoup获取指定class样式的div的实现
Dec 07 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
You might like
水质对咖图啡风味的影响具体有哪些
2021/03/03 冲泡冲煮
PHP如何得到当前页和上一页的地址?
2006/11/27 PHP
zen_cart实现支付前生成订单的方法
2016/05/06 PHP
javascript document.images实例
2008/05/27 Javascript
javascript cookie操作类的实现代码小结附使用方法
2010/06/02 Javascript
js实现网站首页图片滚动显示
2013/02/04 Javascript
jQuery中serializeArray()与serialize()的区别实例分析
2015/12/09 Javascript
深入探究JavaScript中for循环的效率问题及相关优化
2016/03/13 Javascript
第二篇Bootstrap起步
2016/06/21 Javascript
JS触摸屏网页版仿app弹窗型滚动列表选择器/日期选择器
2016/10/30 Javascript
前端构建工具之gulp的语法教程
2017/06/12 Javascript
详解ajax的data参数错误导致页面崩溃
2018/04/30 Javascript
Nodejs实现爬虫抓取数据实例解析
2018/07/05 NodeJs
vue elementUI 表单校验的实现代码(多层嵌套)
2019/11/06 Javascript
[01:03:18]DOTA2-DPC中国联赛 正赛 RNG vs Dynasty BO3 第一场 1月29日
2021/03/11 DOTA
详解python 拆包可迭代数据如tuple, list
2017/12/29 Python
python中append实例用法总结
2019/07/30 Python
python 字典 setdefault()和get()方法比较详解
2019/08/07 Python
Python for循环及基础用法详解
2019/11/08 Python
Python3中对json格式数据的分析处理
2021/01/28 Python
美国设计师精美珠宝购物网:Netaya
2016/08/28 全球购物
Dockers鞋官网:Dockers Shoes
2018/11/13 全球购物
美赞臣新加坡官方旗舰店:Enfagrow新加坡
2019/05/15 全球购物
会计电算化专业毕业生求职信范文
2013/12/10 职场文书
计算机通信工程专业毕业生推荐信
2013/12/24 职场文书
优秀小学生家长评语
2014/01/30 职场文书
2014年基层党组织公开承诺书
2014/03/29 职场文书
倡导文明标语
2014/06/16 职场文书
2014年小学少先队工作总结
2014/12/18 职场文书
毕业实习感受与体会
2015/05/26 职场文书
浅谈Python中的函数(def)及参数传递操作
2021/05/25 Python
浅谈MySQL表空间回收的正确姿势
2021/10/05 MySQL
MySQL sql模式设置引起的问题
2022/05/15 MySQL
Nginx利用Logrotate实现日志分割
2022/05/20 Servers
Win11 KB5015814遇安装失败 影响开始菜单性能解决方法
2022/07/15 数码科技
redis lua限流算法实现示例
2022/07/15 Redis