将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用urllib模块开发的多线程豆瓣小站mp3下载器
Jan 16 Python
分析在Python中何种情况下需要使用断言
Apr 01 Python
详解Python的Django框架中inclusion_tag的使用
Jul 21 Python
总结网络IO模型与select模型的Python实例讲解
Jun 27 Python
Python装饰器知识点补充
May 28 Python
使用Python 正则匹配两个特定字符之间的字符方法
Dec 24 Python
解析pip安装第三方库但PyCharm中却无法识别的问题及PyCharm安装第三方库的方法教程
Mar 10 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 Python
Python实现清理微信僵尸粉功能示例【基于itchat模块】
May 29 Python
python opencv通过按键采集图片源码
May 20 Python
Python爬虫之自动爬取某车之家各车销售数据
Jun 02 Python
Python可视化动图组件ipyvizzu绘制惊艳的可视化动图
Apr 21 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
PHP关于IE下的iframe跨域导致session丢失问题解决方法
2013/10/10 PHP
PHP图片处理之图片背景、画布操作
2014/11/19 PHP
php表单文件iframe异步上传实例讲解
2017/07/26 PHP
return false;和e.preventDefault();的区别
2010/07/11 Javascript
图片延迟加载的实现代码(模仿懒惰)
2013/03/29 Javascript
html5 canvas js(数字时钟)实例代码
2013/12/23 Javascript
jQuery实现鼠标可拖动调整表格列宽度
2014/05/26 Javascript
javascript实现输出指定行数正方形图案的方法
2015/08/03 Javascript
微信小程序 picker 组件详解及简单实例
2017/01/10 Javascript
原生js轮播(仿慕课网)
2017/02/15 Javascript
微信小程序开发之从相册获取图片 使用相机拍照 本地图片上传
2017/04/18 Javascript
vue组件实现可搜索下拉框扩展
2020/10/23 Javascript
BootStrap中的模态框(modal,弹出层)功能示例代码
2018/11/02 Javascript
js中this的指向问题归纳总结
2018/11/28 Javascript
elementUI多选框反选的实现代码
2019/04/03 Javascript
bootstrap-table formatter 使用vue组件的方法
2019/05/09 Javascript
仿iPhone通讯录制作小程序自定义选择组件的实现
2019/05/23 Javascript
详解一些适用于Node.js的命名约定
2019/12/08 Javascript
[15:09]DOTA2国际邀请赛采访专栏:Loda
2013/08/06 DOTA
Python实现获取某天是某个月中的第几周
2015/02/11 Python
用Python展示动态规则法用以解决重叠子问题的示例
2015/04/02 Python
解决nohup重定向python输出到文件不成功的问题
2018/05/11 Python
Python线程池模块ThreadPoolExecutor用法分析
2018/12/28 Python
如何基于python生成list的所有的子集
2019/11/11 Python
如何使用python记录室友的抖音在线时间
2020/06/29 Python
Python的3种运行方式:命令行窗口、Python解释器、IDLE的实现
2020/10/10 Python
css3 盒模型以及box-sizing属性全面了解
2016/09/20 HTML / CSS
css 如何让背景图片拉伸填充避免重复显示
2013/07/11 HTML / CSS
详解HTML5中的拖放事件(Drag 和 drop)
2016/11/14 HTML / CSS
英国当代时尚和街头服饰店:18montrose
2018/12/15 全球购物
高三自我评价
2014/02/01 职场文书
优良学风班申请材料
2014/02/13 职场文书
爱国主义演讲稿
2014/05/07 职场文书
Axios取消重复请求的方法实例详解
2021/06/15 Javascript
python自动计算图像数据集的RGB均值
2021/06/18 Python
AudioContext 实现音频可视化(web技术分享)
2022/02/24 Javascript