将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python保存字符串到文件的方法
Jul 01 Python
举例讲解Python中的身份运算符的使用方法
Oct 13 Python
基于python中staticmethod和classmethod的区别(详解)
Oct 24 Python
python实现分页效果
Oct 25 Python
如何通过python画loss曲线的方法
Jun 26 Python
关于Python中定制类的比较运算实例
Dec 19 Python
Python3爬虫中Selenium的用法详解
Jul 10 Python
python判断元素是否存在的实例方法
Sep 24 Python
pandas处理csv文件的方法步骤
Oct 16 Python
python实现三阶魔方还原的示例代码
Apr 28 Python
Python上下文管理器Content Manager
Jun 26 Python
python中的getter与setter你了解吗
Mar 24 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
PHP统计数值数组中出现频率最多的10个数字的方法
2015/04/20 PHP
wamp服务器访问php非常缓慢的解决过程
2015/07/01 PHP
Yii调试查看执行SQL语句的方法
2016/07/15 PHP
php+jQuery实现的三级导航栏下拉菜单显示效果
2017/08/10 PHP
window.onload 加载完毕的问题及解决方案(上)
2009/07/09 Javascript
Jquery AJAX 框架的使用方法
2009/11/03 Javascript
页面加载完毕后滚动条自动滚动一定位置
2014/02/20 Javascript
Internet Explorer 11 浏览器介绍:别叫我IE
2014/09/28 Javascript
JavaScript图像延迟加载库Echo.js
2016/04/05 Javascript
浅析jquery与checkbox的checked属性的问题
2016/04/27 Javascript
JavaScript代码性能优化总结(推荐)
2016/05/16 Javascript
jQuery检查元素存在性(推荐)
2016/09/17 Javascript
js实现类bootstrap模态框动画
2017/02/07 Javascript
js实现控制文件拖拽并获取拖拽内容功能
2018/02/17 Javascript
vue动态改变背景图片demo分享
2018/09/13 Javascript
Vue基于vuex、axios拦截器实现loading效果及axios的安装配置
2019/04/26 Javascript
vue+django实现一对一聊天功能的实例代码
2019/07/17 Javascript
Vue数据双向绑定底层实现原理
2019/11/22 Javascript
Python循环语句中else的用法总结
2016/09/11 Python
Python自动化运维_文件内容差异对比分析
2017/12/13 Python
用TensorFlow实现戴明回归算法的示例
2018/05/02 Python
python3通过selenium爬虫获取到dj商品的实例代码
2019/04/25 Python
Python语言进阶知识点总结
2019/05/28 Python
django数据模型(Model)的字段类型解析
2019/12/25 Python
Python 多线程共享变量的实现示例
2020/04/17 Python
python爬取抖音视频的实例分析
2021/01/19 Python
html特殊符号示例 html特殊字符编码对照表
2014/01/14 HTML / CSS
玩具反斗城西班牙网上商城:ToysRUs西班牙
2017/01/19 全球购物
英国度假别墅预订:Sykes Cottages
2017/06/12 全球购物
科颜氏印度官网:Kiehl’s印度
2021/02/20 全球购物
说说你所熟悉或听说过的j2ee中的几种常用模式?及对设计模式的一些看法
2012/05/24 面试题
node中使用shell脚本的方法步骤
2021/03/23 Javascript
大专自我鉴定范文
2013/10/01 职场文书
乡镇领导班子四风整顿行动工作汇报
2014/10/25 职场文书
鉴史问廉观后感
2015/06/10 职场文书
Python代码实现双链表
2022/05/25 Python