Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python类属性的延迟计算
Oct 22 Python
对pandas的dataframe绘图并保存的实现方法
Aug 05 Python
Python使用getpass库读取密码的示例
Oct 10 Python
Tornado 多进程实现分析详解
Jan 12 Python
Django模型序列化返回自然主键值示例代码
Jun 12 Python
python scrapy爬虫代码及填坑
Aug 12 Python
基于Python2、Python3中reload()的不同用法介绍
Aug 12 Python
python 列表、字典和集合的添加和删除操作
Dec 16 Python
python装饰器代替set get方法实例
Dec 19 Python
JupyterNotebook 输出窗口的显示效果调整方法
Apr 13 Python
Python常用外部指令执行代码实例
Nov 05 Python
Python爬虫入门案例之爬取去哪儿旅游景点攻略以及可视化分析
Oct 16 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
PHP配置文件中最常用四个ini函数
2007/03/19 PHP
详解php中反射的应用
2016/03/15 PHP
php设计模式之单例模式代码
2016/06/11 PHP
PHP preg_match实现正则表达式匹配功能【输出是否匹配及匹配值】
2017/07/19 PHP
PHP文件上传小程序 适合初学者学习!
2019/05/23 PHP
一个可拖拽列宽表格实例演示
2012/11/26 Javascript
在JavaScript中使用开平方根的sqrt()方法
2015/06/15 Javascript
Ajax中解析Json的两种方法对比分析
2015/06/25 Javascript
实现JavaScript的组成----BOM和DOM详解
2016/05/18 Javascript
indexedDB bootstrap angularjs之 MVC DOMO (应用示例)
2016/06/20 Javascript
jQuery用FormData实现文件上传的方法
2016/11/21 Javascript
微信小程序 侧滑删除(左滑删除)
2017/05/23 Javascript
深入理解ES7的async/await的用法
2017/09/09 Javascript
关于Google发布的JavaScript代码规范你要知道哪些
2018/04/04 Javascript
vue实现div拖拽互换位置
2020/07/29 Javascript
js获取对象,数组所有属性键值(key)和对应值(value)的方法示例
2019/06/19 Javascript
JavaScript console的使用方法实例分析
2020/04/28 Javascript
python模块restful使用方法实例
2013/12/10 Python
使用Python脚本对Linux服务器进行监控的教程
2015/04/02 Python
python实现的简单猜数字游戏
2015/04/04 Python
python web框架学习笔记
2016/05/03 Python
python 巧用正则寻找字符串中的特定字符的位置方法
2018/05/02 Python
详解Python用三种方式统计词频的方法
2019/07/29 Python
更新升级python和pip版本后不生效的问题解决
2020/04/17 Python
Tensorflow加载Vgg预训练模型操作
2020/05/26 Python
paramiko使用tail实时获取服务器的日志输出详解
2020/12/06 Python
香港艺人陈冠希创办的潮流品牌:JUICESTORE
2021/03/04 全球购物
应用心理学个人求职信范文
2013/12/11 职场文书
个人查摆剖析材料
2014/02/04 职场文书
酒店值班经理的工作职责范本
2014/02/18 职场文书
上课看小说检讨书
2014/02/22 职场文书
抽样调查项目计划书
2014/04/24 职场文书
检讨书范文300字
2015/01/28 职场文书
2015年音乐教学工作总结
2015/07/22 职场文书
干货:如何写好观后感 !
2019/05/21 职场文书
Vue实现下拉加载更多
2021/05/09 Vue.js