TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的random模块及加权随机算法的python实现方法
Jan 04 Python
Python for循环与range函数的使用详解
Mar 23 Python
详解PyCharm+QTDesigner+PyUIC使用教程
Jun 13 Python
Django 请求Request的具体使用方法
Nov 11 Python
Django 实现外键去除自动添加的后缀‘_id’
Nov 15 Python
Python continue语句实例用法
Feb 06 Python
Django 用户登陆访问限制实例 @login_required
May 13 Python
Python爬虫获取豆瓣电影并写入excel
Jul 31 Python
利用Python如何制作贪吃蛇及AI版贪吃蛇详解
Aug 24 Python
python代数式括号有效性检验示例代码
Oct 04 Python
通过Python pyecharts输出保存图片代码实例
Nov 25 Python
浅谈matplotlib默认字体设置探索
Feb 03 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
PHP观察者模式示例【Laravel框架中有用到】
2018/06/15 PHP
php多进程并发编程防止出现僵尸进程的方法分析
2020/02/28 PHP
自己写的兼容ie和ff的在线文本编辑器类似ewebeditor
2012/12/12 Javascript
JavaScript获取路径设计源码
2014/05/22 Javascript
JavaScript实现select添加option
2015/07/03 Javascript
jQuery+css实现的tab切换标签(兼容各浏览器)
2016/01/28 Javascript
全面解析Bootstrap表单样式的使用
2016/09/09 Javascript
Javascript从数组中随机取出不同元素的两种方法
2016/09/22 Javascript
js 中文汉字转Unicode、Unicode转中文汉字、ASCII转换Unicode、Unicode转换ASCII、中文转换
2016/12/06 Javascript
使用vue2实现购物车和地址选配功能
2018/03/29 Javascript
Vue.set()动态的新增与修改数据,触发视图更新的方法
2018/09/15 Javascript
vue中如何实现后台管理系统的权限控制的方法示例
2018/09/19 Javascript
JS实现横向轮播图(中级版)
2020/01/18 Javascript
JavaScript实现横版菜单栏
2020/03/17 Javascript
[57:12]完美世界DOTA2联赛循环赛 Inki vs Matador BO2第一场 10.31
2020/11/02 DOTA
python网络编程示例(客户端与服务端)
2014/04/24 Python
Python中使用装饰器时需要注意的一些问题
2015/05/11 Python
python实现用于测试网站访问速率的方法
2015/05/26 Python
Python-嵌套列表list的全面解析
2016/06/08 Python
Python书单 不将就
2017/07/11 Python
简单了解Python中的几种函数
2017/11/03 Python
PyCharm+PySpark远程调试的环境配置的方法
2018/11/29 Python
对python的bytes类型数据split分割切片方法
2018/12/04 Python
漂亮的Django Markdown富文本app插件的实现
2019/01/02 Python
基于python cut和qcut的用法及区别详解
2019/11/22 Python
详解Python 实现 ZeroMQ 的三种基本工作模式
2020/03/24 Python
解决Django部署设置Debug=False时xadmin后台管理系统样式丢失
2020/04/07 Python
解决安装新版PyQt5、PyQT5-tool后打不开并Designer.exe提示no Qt platform plugin的问题
2020/04/24 Python
德国最大的服装、鞋子和配件在线商店之一:Outfits24
2019/07/23 全球购物
中介公司区域经理岗位职责范本
2014/03/02 职场文书
白莲教口号
2014/06/18 职场文书
作风建设剖析材料
2014/10/06 职场文书
2014幼儿园教育教学工作总结
2014/12/17 职场文书
2016年第32个教师节致辞
2015/11/26 职场文书
Ajax实现局部刷新的方法实例
2021/03/31 Javascript
利用Python第三方库实现预测NBA比赛结果
2021/06/21 Python