将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中列表的一些基本操作知识汇总
May 20 Python
python发送HTTP请求的方法小结
Jul 08 Python
Python selenium如何设置等待时间
Sep 15 Python
python妙用之编码的转换详解
Apr 21 Python
浅谈Python中的zip()与*zip()函数详解
Feb 24 Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 Python
python 实现在txt指定行追加文本的方法
Apr 29 Python
Django框架的中的setting.py文件说明详解
Oct 15 Python
使用python3构建文件传输的方法
Feb 13 Python
详解Python Matplot中文显示完美解决方案
Mar 07 Python
如何使用Python 打印各种三角形
Jun 28 Python
Python中字符串List按照长度排序
Jul 01 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
php中随机显示图片的函数代码
2011/06/23 PHP
PHP判断变量是否为0的方法
2014/02/08 PHP
ThinkPHP3.1新特性之字段合法性检测详解
2014/06/19 PHP
php实现的简单日志写入函数
2015/03/31 PHP
PHP 自动加载的简单实现(推荐)
2016/08/12 PHP
jQuery的运行机制和设计理念分析
2011/04/05 Javascript
jquery实现简单的无缝滚动
2015/04/15 Javascript
jQuery ui autocomplete选择列表被Bootstrap模态窗遮挡的完美解决方法
2016/09/23 Javascript
详谈js对url进行编码和解码(三种方式的区别)
2017/08/16 Javascript
JS中appendChild追加子节点无效的解决方法
2018/10/14 Javascript
微信小程序实现日历功能
2018/11/27 Javascript
layui实现根据table数据判断按钮显示情况的方法
2019/09/26 Javascript
vue-resourc发起异步请求的方法
2020/02/11 Javascript
Element Popover 弹出框的使用示例
2020/07/26 Javascript
JavaScript数组常用的增删改查与其他属性详解
2020/10/13 Javascript
js数组的基本使用总结
2021/01/18 Javascript
[03:17]史诗级大片应援2018DOTA2国际邀请赛 致敬每一位坚守遗迹的勇士
2018/07/20 DOTA
Python生成不重复随机值的方法
2015/05/11 Python
Python易忽视知识点小结
2015/05/25 Python
Python编程产生非均匀随机数的几种方法代码分享
2017/12/13 Python
python3+dlib实现人脸识别和情绪分析
2018/04/21 Python
python实现守护进程、守护线程、守护非守护并行
2018/05/05 Python
pandas使用get_dummies进行one-hot编码的方法
2018/07/10 Python
django DRF图片路径问题的解决方法
2018/09/10 Python
python实现socket+threading处理多连接的方法
2019/07/23 Python
Python操作列表常用方法实例小结【创建、遍历、统计、切片等】
2019/10/25 Python
Python 实现简单的客户端认证
2020/07/29 Python
python代码实现图书管理系统
2020/11/30 Python
公司聘任书模板
2014/03/29 职场文书
倡议书格式
2014/04/14 职场文书
安全负责人任命书
2014/06/06 职场文书
销售类求职信
2014/06/13 职场文书
党的群众路线对照检查材料思想汇报(学校)
2014/10/04 职场文书
钓鱼岛事件感想
2015/08/11 职场文书
2016年“世界环境日”校园广播稿
2015/12/18 职场文书
HTML+JS实现在线朗读器
2022/02/15 Javascript