TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用Python读取文件的四种不同方法比对
May 18 Python
Python中import机制详解
Nov 14 Python
使用pandas读取csv文件的指定列方法
Apr 21 Python
Python实现识别图片内容的方法分析
Jul 11 Python
Python中flatten( )函数及函数用法详解
Nov 02 Python
Python3 无重复字符的最长子串的实现
Oct 08 Python
Python 切分数组实例解析
Nov 07 Python
Python如何实现小程序 无限求和平均
Feb 18 Python
使用pth文件添加Python环境变量方式
May 26 Python
Python用类实现扑克牌发牌的示例代码
Jun 01 Python
Django全局启用登陆验证login_required的方法
Jun 02 Python
一文带你了解Python 四种常见基础爬虫方法介绍
Dec 04 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
一拳超人中怪人协会钦定! S级别最强四人!
2020/03/02 日漫
php 判断访客是否为搜索引擎蜘蛛的函数代码
2011/07/29 PHP
php实现获取文件mime类型的方法
2015/02/11 PHP
锋利的jQuery 要点归纳(二) jQuery中的DOM操作(下)
2010/03/23 Javascript
一些实用的jQuery代码片段收集
2011/07/12 Javascript
一个不错的字符串转码解码函数(自写)
2014/07/31 Javascript
JQuery节点元素属性操作方法
2015/06/11 Javascript
针对初学者的jQuery入门指南
2015/08/15 Javascript
Bootstrap每天必学之基础排版
2015/11/20 Javascript
javascript 数组的定义和数组的长度
2016/06/07 Javascript
AngularJS Ajax详解及示例代码
2016/08/17 Javascript
jQuery无缝轮播图代码
2016/12/22 Javascript
Vue中的ref作用详解(实现DOM的联动操作)
2017/08/21 Javascript
element-ui upload组件多文件上传的示例代码
2018/10/17 Javascript
Vue项目中使用better-scroll实现一个轮播图自动播放功能
2018/12/03 Javascript
Nodejs环境实现socket通信过程解析
2020/07/03 NodeJs
如何基于jQuery实现五角星评分
2020/09/02 jQuery
vue使用echarts图表自适应的几种解决方案
2020/12/04 Vue.js
python冒泡排序算法的实现代码
2013/11/21 Python
用pickle存储Python的原生对象方法
2017/04/28 Python
Django+Ajax+jQuery实现网页动态更新的实例
2018/05/28 Python
[原创]Python入门教程3. 列表基本操作【定义、运算、常用函数】
2018/10/30 Python
Python+OpenCV实现图像融合的原理及代码
2018/12/03 Python
python实现微信自动回复及批量添加好友功能
2019/07/03 Python
使用Python代码实现Linux中的ls遍历目录命令的实例代码
2019/09/07 Python
利用Python产生加密表和解密表的实现方法
2019/10/15 Python
python3实现往mysql中插入datetime类型的数据
2020/03/02 Python
零基础学python应该从哪里入手
2020/08/11 Python
西班牙香水和化妆品购物网站:Arenal Perfumerías
2019/03/01 全球购物
Linden Leaves官网:新西兰纯净护肤品
2020/12/20 全球购物
高中毕业自我鉴定
2013/12/19 职场文书
《愚公移山》教学反思
2014/02/20 职场文书
美术教师岗位职责
2014/03/18 职场文书
弘扬职业精神演讲稿
2014/03/20 职场文书
出生公证委托书
2014/04/03 职场文书
植树节活动总结
2014/04/30 职场文书