tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例


Posted in Python onJune 22, 2020

升级到tf 2.0后, 训练的模型想转成1.x版本的.pb模型, 但之前提供的通过ckpt转pb模型的方法都不可用(因为保存的ckpt不再有.meta)文件, 尝试了好久, 终于找到了一个方法可以迂回转到1.x版本的pb模型.

Note: 本方法首先有些要求需要满足:

可以拿的到模型的网络结构定义源码

网络结构里面的所有操作都是通过tf.keras完成的, 不能出现类似tf.nn 的tensorflow自己的操作符

tf2.0下保存的模型是.h5格式的,并且仅保存了weights, 即通过model.save_weights保存的模型.

在tf1.x的环境下, 将tf2.0保存的weights转为pb模型:

如果在tf2.0下保存的模型符合上述的三个定义, 那么这个.h5文件在1.x环境下其实是可以直接用的, 因为都是通过tf.keras高级封装了,2.0版本和1.x版本不存在特别大的区别,我自己的模型是可以直接用的.

import tensorflow as tf
import os
from nets.efficientNet import *
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# 这个代码网上说需要加上, 如果模型里有dropout , bn层的话, 我测试过加不加结果都一样, 保险起见还是加上吧
tf.keras.backend.set_learning_phase(0)

# 首先是定义你的模型, 这个需要和tf2.0下一毛一样
inputs = tf.keras.Input(shape=(224, 224, 3), name='modelInput')
outputs = yourModel(inputs, training=False)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.load_weights('save_weights.h5')
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
  """
  Freezes the state of a session into a pruned computation graph.

  Creates a new computation graph where variable nodes are replaced by
  constants taking their current value in the session. The new graph will be
  pruned so subgraphs that are not necessary to compute the requested
  outputs are removed.
  @param session The TensorFlow session to be frozen.
  @param keep_var_names A list of variable names that should not be frozen,
             or None to freeze all the variables in the graph.
  @param output_names Names of the relevant graph outputs.
  @param clear_devices Remove the device directives from the graph for better portability.
  @return The frozen graph definition.
  """
  from tensorflow.python.framework.graph_util import convert_variables_to_constants
  graph = session.graph
  with graph.as_default():
    freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
    output_names = output_names or []
    output_names += [v.op.name for v in tf.global_variables()]
    # Graph -> GraphDef ProtoBuf
    input_graph_def = graph.as_graph_def(add_shapes=True)
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    frozen_graph = convert_variables_to_constants(session, input_graph_def,
                           output_names, freeze_var_names)
    return frozen_graph

frozen_graph = freeze_session(tf.keras.backend.get_session(), output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, "model", "tf_model.pb", as_text=False)

运行成功后, 会在当前目录下生成一个model文件夹, 里面有生成的tf_model.pb文件, 至此, 我们就完成了将tf2.0下训练的模型转到tf1.x下的pb模型, 这样,就可以用这个pb模型做其它推理或者转tvm ncnn等模型转换工作.

这个转换的重点就是通过keras这个中间商来完成, 所以我们定义的模型就必须要满足这个中间商定义的条件

补充知识:tensorflow2.0降级及如何从别的版本升到2.0

代码实践《tensorflow实战GOOGLE深度学习框架》时,由于本机安装的tensorflow为2.0版本与配套书籍代码1.4的API不兼容,只得将tensorflow降级为1.4.0版本使用,降级方法如下

1 pip uninstall tensorflow

tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例

2 pip install tensorflow==1.14.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例

验证

import tensorflow as tf
print(tf.version)

tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例

二 从别的版本升级到2.0

自动卸载与其相关包

pip uninstall tensorflow

安装某版本

pip install --no-cache-dir tensorflow==x.xx (此处填写2.0)

tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例

验证

tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例

以上这篇tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python程序与服务器连接的WSGI接口
Apr 29 Python
Python中方法链的使用方法
Feb 23 Python
详解python的几种标准输出重定向方式
Aug 15 Python
Python用模块pytz来转换时区
Aug 19 Python
Django使用Mysql数据库已经存在的数据表方法
May 27 Python
python机器学习之KNN分类算法
Aug 29 Python
Pandas之MultiIndex对象的示例详解
Jun 25 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 Python
python中sklearn的pipeline模块实例详解
May 21 Python
python中的yield from语法快速学习
Nov 06 Python
Python time库的时间时钟处理
May 02 Python
仅用几行Python代码就能复制她的U盘文件?
Jun 26 Python
利用Vscode进行Python开发环境配置的步骤
Jun 22 #Python
Python Excel vlookup函数实现过程解析
Jun 22 #Python
宝塔面板成功部署Django项目流程(图文)
Jun 22 #Python
python和php哪个更适合写爬虫
Jun 22 #Python
如何理解python对象
Jun 21 #Python
什么是python的必选参数
Jun 21 #Python
什么是python的自省
Jun 21 #Python
You might like
PHP文件下载类
2006/12/06 PHP
PHP 数据结构 算法 三元组 Triplet
2011/07/02 PHP
Linux系统递归生成目录中文件的md5的方法
2015/06/29 PHP
PHP的Yii框架的常用日志操作总结
2015/12/08 PHP
PHP数组Key强制类型转换实现原理解析
2020/09/01 PHP
JavaScript创建对象的写法
2013/08/29 Javascript
JavaScript运行机制之事件循环(Event Loop)详解
2014/10/10 Javascript
JavaScript中的类数组对象介绍
2014/12/30 Javascript
JQuery选择器绑定事件及修改内容的方法
2015/01/23 Javascript
JavaScript实现网站访问次数统计代码
2015/08/12 Javascript
jQuery实现TAB风格的全国省份城市滑动切换效果代码
2015/08/24 Javascript
js仿百度登录页实现拖动窗口效果
2016/03/11 Javascript
Select下拉框模糊查询功能实现代码
2016/07/22 Javascript
Angular.js中ng-if、ng-show和ng-hide的区别介绍
2017/01/20 Javascript
Angular中响应式表单的三种更新值方法详析
2017/08/22 Javascript
JS和jQuery通过this获取html标签中的属性值(实例代码)
2017/09/11 jQuery
Vue中使用vux配置代码详解
2018/09/16 Javascript
ES6 更易于继承的类语法的使用
2019/02/11 Javascript
JavaScript生成随机验证码代码实例
2019/09/28 Javascript
如何检测JavaScript中的死循环示例详解
2020/08/30 Javascript
Python装饰器使用示例及实际应用例子
2015/03/06 Python
Pycharm设置界面全黑的方法
2018/05/23 Python
浅谈pytorch池化maxpool2D注意事项
2020/02/18 Python
django 模型中的计算字段实例
2020/05/19 Python
获取CSDN文章内容并转换为markdown文本的python
2020/09/06 Python
HTML5手指下滑弹出负一屏阻止移动端浏览器内置下拉刷新功能的实现代码
2020/04/10 HTML / CSS
解决html5中的video标签ios系统中无法播放使用的问题
2020/08/10 HTML / CSS
幼儿园母亲节活动方案
2014/03/10 职场文书
消防安全责任书范本
2014/04/15 职场文书
2015年感恩母亲节的演讲稿
2015/03/18 职场文书
护士求职自荐信
2015/03/25 职场文书
傅雷家书读书笔记
2015/06/29 职场文书
2015年大学生暑期实习报告
2015/07/13 职场文书
2016年学校安全教育月活动总结
2016/04/06 职场文书
MySQL Threads_running飙升与慢查询的相关问题解决
2021/05/08 MySQL
输入框跟随文字内容适配宽实现示例
2022/08/14 Javascript