Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解


Posted in Python onFebruary 11, 2020

一、保存:

graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量

参数2:GraphDef 对象,它描述了计算网络

参数3:Graph图中需要输出的节点的名称的列表

返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:

constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())

需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。

二、恢复:

恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:

graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )

三、代码:

import tensorflow as tf 
from tensorflow.python.framework import graph_util
 
pbName = 'graphA.pb'
def graphCreate() :
  with tf.Session() as sess :
    var1 = tf.placeholder ( tf.int32 , name='var1' ) 
    var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
                       #'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
                       # var2的名称是'var2:0'
    var3 = tf.Variable( 30 , name='var3' )
    var4 = tf.Variable( 40 , name='var4' )
    var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
    sum = tf.Variable( 4, name='sum' )
    sum = tf.add ( var1 , var2, name = 'var1_var2' ) 
    sum = tf.add( sum , var3 , name='sum_var3' )
    sumOps = tf.add( sum , var4 , name='sum_operation' )
    oper = tf.get_default_graph().get_operations()
    with open( 'operation.csv','wt' ) as f:
      s = 'name,type,output\n'
      f.write( s ) 
      for o in oper:
        s = o.name
        s += ','+ o.type 
        inp = o.inputs
        oup = o.outputs
        for iip in inp :
          s #s += ','+ str(iip)
        for iop in oup :
          s += ',' + str(iop)
        s += '\n'
        f.write( s ) 
         
      for var in tf.global_variables():
        print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
                        #张量的名字使用操作名加:0来表示
    init = tf.global_variables_initializer()
    sess.run( init )
    sess.run( var4op )
    print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
 
    constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
    with open( pbName, mode='wb') as f:
      f.write(constant_graph.SerializeToString())
 
def graphGet() :
  print("start get:" )
  with tf.Graph().as_default():
    graph0 = tf.GraphDef()
    with open( pbName, mode='rb') as f:
      graph0.ParseFromString( f.read() )
      tf.import_graph_def( graph0 , name = '' )
    with tf.Session() as sess :
      init = tf.global_variables_initializer()
      sess.run(init)
      v1 = sess.graph.get_tensor_by_name('var1:0' )
      v2 = sess.graph.get_tensor_by_name('var2:0' )
      v3 = sess.graph.get_tensor_by_name('var3:0' )
      v4 = sess.graph.get_tensor_by_name('var4:0' )
      
      sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
      print('sumTensor is : ' , sumTensor )
      print( sess.run( sumTensor , feed_dict={v1:1} ) ) 
  
graphCreate()
graphGet()

四、保存pb函数代码里的操作名称/类型/返回的张量:

operation name operation type output
var1 Placeholder Tensor("var1:0" dtype=int32)
var2/initial_value Const Tensor("var2/initial_value:0" shape=() dtype=int32)
var2 VariableV2 Tensor("var2:0" shape=() dtype=int32_ref)
var2/Assign Assign Tensor("var2/Assign:0" shape=() dtype=int32_ref)
var2/read Identity Tensor("var2/read:0" shape=() dtype=int32)
var3/initial_value Const Tensor("var3/initial_value:0" shape=() dtype=int32)
var3 VariableV2 Tensor("var3:0" shape=() dtype=int32_ref)
var3/Assign Assign Tensor("var3/Assign:0" shape=() dtype=int32_ref)
var3/read Identity Tensor("var3/read:0" shape=() dtype=int32)
var4/initial_value Const Tensor("var4/initial_value:0" shape=() dtype=int32)
var4 VariableV2 Tensor("var4:0" shape=() dtype=int32_ref)
var4/Assign Assign Tensor("var4/Assign:0" shape=() dtype=int32_ref)
var4/read Identity Tensor("var4/read:0" shape=() dtype=int32)
var4op1/value Const Tensor("var4op1/value:0" shape=() dtype=int32)
var4op1 Assign Tensor("var4op1:0" shape=() dtype=int32_ref)
sum/initial_value Const Tensor("sum/initial_value:0" shape=() dtype=int32)
sum VariableV2 Tensor("sum:0" shape=() dtype=int32_ref)
sum/Assign Assign Tensor("sum/Assign:0" shape=() dtype=int32_ref)
sum/read Identity Tensor("sum/read:0" shape=() dtype=int32)
var1_var2 Add Tensor("var1_var2:0" dtype=int32)
sum_var3 Add Tensor("sum_var3:0" dtype=int32)
sum_operation Add Tensor("sum_operation:0" dtype=int32)

以上这篇Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用点操作符访问字典(dict)数据的方法
Mar 16 Python
详解Python的Django框架中的templates设置
May 11 Python
python爬虫获取新浪新闻教学
Dec 23 Python
提升Python程序性能的7个习惯
Apr 14 Python
PyQt4 treewidget 选择改变颜色,并设置可编辑的方法
Jun 17 Python
Python判断字符串是否xx开始或结尾的示例
Aug 08 Python
python 遍历pd.Series的index和value
Nov 26 Python
Python利用PyExecJS库执行JS函数的案例分析
Dec 18 Python
python操作cfg配置文件方式
Dec 22 Python
深入浅析Python 函数注解与匿名函数
Feb 24 Python
python Tkinter的简单入门教程
Apr 11 Python
Python语言内置数据类型
Feb 24 Python
TensorFlow:将ckpt文件固化成pb文件教程
Feb 11 #Python
TensorFlow获取加载模型中的全部张量名称代码
Feb 11 #Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 #Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
Feb 11 #Python
python 函数中的参数类型
Feb 11 #Python
python正则过滤字母、中文、数字及特殊字符方法详解
Feb 11 #Python
python3正则模块re的使用方法详解
Feb 11 #Python
You might like
PHP提示Deprecated: mysql_connect(): The mysql extension is deprecated的解决方法
2014/08/28 PHP
PHP移动文件指针ftell()、fseek()、rewind()函数总结
2014/11/18 PHP
ThinkPHP控制器里javascript代码不能执行的解决方法
2014/11/22 PHP
PHP的mysqli_select_db()函数讲解
2019/01/23 PHP
Javascript 构造函数,公有,私有特权和静态成员定义方法
2009/11/30 Javascript
JS返回上一页实例代码通过图片和按钮分别实现
2013/08/16 Javascript
JS 删除字符串最后一个字符的实现代码
2014/02/20 Javascript
Jquery中Event对象属性小结
2015/02/27 Javascript
浅谈Nodejs观察者模式
2015/10/13 NodeJs
JavaScript实现广告弹窗效果
2016/08/09 Javascript
vue.js实现仿原生ios时间选择组件实例代码
2016/12/21 Javascript
在百度搜索结果中去除掉一些网站的资料(通过js控制不让显示)
2017/05/02 Javascript
Angular实现响应式表单
2017/08/04 Javascript
在vue项目中集成graphql(vue-ApolloClient)
2018/09/08 Javascript
jQuery使用$.extend(true,object1, object2);实现深拷贝对象的方法分析
2019/03/06 jQuery
Vue CLI项目 axios模块前后端交互的使用(类似ajax提交)
2019/09/01 Javascript
javascript实现函数柯里化与反柯里化过程解析
2019/10/08 Javascript
原生JS实现拖拽功能
2020/12/16 Javascript
Python中optparser库用法实例详解
2018/01/26 Python
Python基于pandas实现json格式转换成dataframe的方法
2018/06/22 Python
Python高级特性切片(Slice)操作详解
2018/09/27 Python
对python 匹配字符串开头和结尾的方法详解
2018/10/27 Python
python实现自动获取IP并发送到邮箱
2018/12/26 Python
Python玩转PDF的各种骚操作
2019/05/06 Python
一行Python代码过滤标点符号等特殊字符
2019/08/12 Python
tensorflow 实现数据类型转换
2020/02/17 Python
20行Python代码实现一款永久免费PDF编辑工具的实现
2020/08/27 Python
波兰在线运动商店:YesSport
2020/07/23 全球购物
县委班子四风对照检查材料思想汇报
2014/09/29 职场文书
故意伤害罪辩护词
2015/05/21 职场文书
2015初中团委工作总结
2015/07/28 职场文书
军事理论课感想
2015/08/11 职场文书
体育部部长竞选稿
2015/11/21 职场文书
2019年教师节祝福语精选,给老师送上真诚的祝福
2019/09/09 职场文书
将图片保存到mysql数据库并展示在前端页面的实现代码
2021/05/02 MySQL
Redisson实现Redis分布式锁的几种方式
2021/08/07 Redis