将TensorFlow的模型网络导出为单个文件的方法


Posted in Python onApril 23, 2018

有时候,我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

我们知道,graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

我们可以采用以下方式冻结权重并保存网络:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

当恢复网络时,可以使用如下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

输出结果为:

[array([[ 7.],
       [ 8.]], dtype=float32)]

可以看到之前的权重确实保存了下来!!

问题来了,我们的网络需要能有一个输入自定义数据的接口啊!不然这玩意有什么用。。别急,当然有办法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

用上述代码重新保存网络至graph.pb,这次我们有了一个输入placeholder,下面来看看怎么恢复网络并输入自定义数据。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))

输出结果为:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到结果没有问题,当然在input_map那里可以替换为新的自定义的placeholder,如下所示:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))

看看输出,同样没有问题。

[array([[ 11.],
       [ 12.]], dtype=float32)]

另外需要说明的一点是,在利用tf.train.write_graph写网络架构的时候,如果令as_text=True了,则在导入网络的时候,需要做一点小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))

参考资料

Is there an example on how to generate protobuf files holding trained Tensorflow graphs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简要讲解Python编程中线程的创建与锁的使用
Feb 28 Python
python实现获取Ip归属地等信息
Aug 27 Python
python的pdb调试命令的命令整理及实例
Jul 12 Python
Python轻量级ORM框架Peewee访问sqlite数据库的方法详解
Jul 20 Python
pygame游戏之旅 按钮上添加文字的方法
Nov 21 Python
celery4+django2定时任务的实现代码
Dec 23 Python
详解Python中正则匹配TAB及空格的小技巧
Jul 26 Python
解决pycharm最左侧Tool Buttons显示不全的问题
Dec 17 Python
利用python实现逐步回归
Feb 24 Python
对python中arange()和linspace()的区别说明
May 03 Python
Python基础之元编程知识总结
May 23 Python
Python Django / Flask如何使用Elasticsearch
Apr 19 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 #Python
tensorflow 使用flags定义命令行参数的方法
Apr 23 #Python
Tensorflow之Saver的用法详解
Apr 23 #Python
python获取文件路径、文件名、后缀名的实例
Apr 23 #Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 #Python
Python基于whois模块简单识别网站域名及所有者的方法
Apr 23 #Python
Python实现自定义顺序、排列写入数据到Excel的方法
Apr 23 #Python
You might like
JS中encodeURIComponent函数用php解码的代码
2012/03/01 PHP
php二维数组排序方法(array_multisort usort)
2013/12/25 PHP
php使用cookie显示用户上次访问网站日期的方法
2015/01/26 PHP
跟着JQuery API学Jquery 之二 属性
2010/04/09 Javascript
js全选实现和判断是否有复选框选中的方法
2015/02/17 Javascript
AngularJS控制器controller正确的通信的方法
2016/01/25 Javascript
浅谈Javascript数组(推荐)
2016/05/17 Javascript
js基本算法:冒泡排序,二分查找的简单实例
2016/10/08 Javascript
JavaScript中访问id对象 属性的方式访问属性(实例代码)
2016/10/28 Javascript
jquery easyui validatebox remote的使用详解
2016/11/09 Javascript
Node.js利用Net模块实现多人命令行聊天室的方法
2016/12/23 Javascript
js仿淘宝评价评分功能
2017/02/28 Javascript
BootStrap入门学习第一篇
2017/08/28 Javascript
JavaScript实现焦点进入文本框内关闭输入法的核心代码
2017/09/20 Javascript
Node.js学习教程之HTTP/2服务器推送【译】
2017/10/31 Javascript
Angular5.1新功能分享
2017/12/21 Javascript
几个你不知道的技巧助你写出更优雅的vue.js代码
2018/06/11 Javascript
Vue中axios的封装(报错、鉴权、跳转、拦截、提示)
2019/08/20 Javascript
layui表格内放置图片,并点击放大的实例
2019/09/10 Javascript
es6 super关键字的理解与应用实例分析
2020/02/15 Javascript
[01:19:46]EG vs Secret 2019国际邀请赛淘汰赛 胜者组 BO3 第二场 8.21.mp4
2020/07/19 DOTA
Python生成密码库功能示例
2017/05/23 Python
python 简单搭建阻塞式单进程,多进程,多线程服务的实例
2017/11/01 Python
python实现感知器算法详解
2017/12/19 Python
python机器学习之随机森林(七)
2018/03/26 Python
Flask web开发处理POST请求实现(登录案例)
2018/07/26 Python
Django添加feeds功能的示例
2018/08/07 Python
pycharm运行程序时在Python console窗口中运行的方法
2018/12/03 Python
Python设计模式之原型模式实例详解
2019/01/18 Python
python实现邮件自动发送
2019/08/10 Python
使用CSS3配合IE滤镜实现渐变和投影的效果
2015/09/06 HTML / CSS
Gap英国官网:Gap UK
2018/07/18 全球购物
Tod’s英国官方网站:意大利奢华手工制作手袋和鞋履
2019/03/15 全球购物
印度第一网上礼品店:IGP.com
2020/02/06 全球购物
2014年环卫工作总结
2014/11/22 职场文书
Windows server 2012 R2 安装IIS服务器
2022/04/29 Servers