TensorFlow固化模型的实现操作


Posted in Python onMay 26, 2020

前言

TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?

生成模型

主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。

freeze_graph

这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:

with tf.Session() as sess:
 saver = tf.train.Saver()
 saver.save(session, "model.ckpt")
 tf.train.write_graph(session.graph_def, '', 'graph.pb')

然后使用TensorFlow源码中的freeze_graph工具进行固化操作:

首先需要build freeze_graph 工具( 需要 bazel ):

bazel build tensorflow/python/tools:freeze_graph

然后使用这个工具进行固化(/path/to/表示文件路径):

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants

其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。

牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。

首先我们需要引入这个方法

from tensorflow.python.framework.graph_util import convert_variables_to_constants

在想要保存的地方加入如下代码,把变量转换成常量

output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])

这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)

with tf.name_scope('output'):
 w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/weight', w_out)
 b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
 tf.summary.histogram('output/biases', b_out)
 out = tf.add(tf.matmul(dense2, w_out), b_out)
 out = tf.nn.softmax(out)
 predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')

由于我们采用了name_scope所以我们在predict之前需要加上output/

生成文件

with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())

第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。

运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。

测试模型

在Python环境下,我们首先需要加载这个模型,代码如下:

with open('./model/rounded_graph.pb', 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 output = tf.import_graph_def(graph_def,
     input_map={'inputs/X:0': newInput_X},
     return_elements=['output/predict:0'])

由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:

newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")

在input_map的参数填入新的placeholder。

在调用我们的网络的时候直接用这个新的placeholder接收数据,如:

text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})

然后就是运行我们的网络,看是否可以运行吧。

以上这篇TensorFlow固化模型的实现操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 实现文件的递归拷贝实现代码
Aug 02 Python
python操作ie登陆土豆网的方法
May 09 Python
使用py2exe在Windows下将Python程序转为exe文件
Mar 04 Python
python实现下载文件的三种方法
Feb 09 Python
浅谈机器学习需要的了解的十大算法
Dec 15 Python
学习Python3 Dlib19.7进行人脸面部识别
Jan 24 Python
Python之csv文件从MySQL数据库导入导出的方法
Jun 21 Python
Python wxpython模块响应鼠标拖动事件操作示例
Aug 23 Python
python绘制多个曲线的折线图
Mar 23 Python
Python中安装easy_install的方法
Nov 18 Python
Python图像处理之图片文字识别功能(OCR)
Jul 30 Python
详解使用python爬取抖音app视频(appium可以操控手机)
Jan 26 Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
You might like
php基础知识:类与对象(5) static
2006/12/13 PHP
DedeCMS dede_channeltype表字段注释
2010/04/07 PHP
PHP中数组合并的两种方法及区别介绍
2012/09/14 PHP
Joomla数据库操作之JFactory::getDBO用法
2016/05/05 PHP
JavaScript的目的分析
2007/01/05 Javascript
jquery 查找iframe父级页面元素的实现代码
2011/08/28 Javascript
jquery Mobile入门—多页面切换示例学习
2013/01/08 Javascript
PHP开发者必须掌握的6个关键字
2014/04/14 Javascript
使用Jquery获取带特殊符号的ID 标签的方法
2014/04/30 Javascript
JS实现统计复选框选中个数并提示确定与取消的方法
2015/07/01 Javascript
jquery siblings获取同辈元素用法实例分析
2016/07/25 Javascript
原生JS实现的放大镜效果实例代码
2016/10/15 Javascript
JS闭包用法实例分析
2017/03/27 Javascript
vue-cli的webpack模板项目配置文件分析
2017/04/01 Javascript
Webpack执行命令参数详解
2017/06/17 Javascript
Vue2.0 vue-source jsonp 跨域请求
2017/08/04 Javascript
利用javascript如何随机生成一定位数的密码
2017/09/22 Javascript
nodejs的安装使用与npm的介绍
2019/09/11 NodeJs
layui动态渲染生成select的option值方法
2019/09/23 Javascript
解决Idea、WebStorm下使用Vue cli脚手架项目无法使用Webpack别名的问题
2019/10/11 Javascript
Python中的复制操作及copy模块中的浅拷贝与深拷贝方法
2016/07/02 Python
TensorFlow实现Batch Normalization
2018/03/08 Python
Python循环结构的应用场景详解
2019/07/11 Python
Django错误:TypeError at / 'bool' object is not callable解决
2019/08/16 Python
python绘制BA无标度网络示例代码
2019/11/21 Python
Python openpyxl 插入折线图实例
2020/04/17 Python
Python内置函数及功能简介汇总
2020/10/13 Python
部队学习十八大感言
2014/01/11 职场文书
学生打架检讨书大全
2014/01/23 职场文书
升旗仪式主持词
2014/03/19 职场文书
学生个人自我鉴定
2014/03/26 职场文书
嘉宾邀请函
2015/01/31 职场文书
大学生党员个人总结
2015/02/13 职场文书
2015小学教师年度工作总结
2015/05/12 职场文书
新郎父亲婚礼致辞
2015/07/27 职场文书
银行大堂经理培训心得体会
2016/01/09 职场文书