keras实现theano和tensorflow训练的模型相互转换


Posted in Python onJune 19, 2020

我就废话不多说了,大家还是直接看代码吧~

</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">

# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
 
def th2tf( model):
  import tensorflow as tf
  ops = []
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      ops.append(tf.assign(layer.W, converted_w).op)
  K.get_session().run(ops)
  return model
 
def tf2th(model):
  for layer in model.layers:
    if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
      original_w = K.get_value(layer.W)
      converted_w = convert_kernel(original_w)
      K.set_value(layer.W, converted_w)
  return model
 
def conv_layer_converted(tf_weights, th_weights, m = 0):
  """
  :param tf_weights:
  :param th_weights:
  :param m: 0-tf2th, 1-th2tf
  :return:
  """
  if m == 0: # tf2th
    tc = keras_text_classifier(weights_path=tf_weights)
    model = tc.loadmodel()
    model = tf2th(model)
    model.save_weights(th_weights)
  elif m == 1: # th2tf
    tc = keras_text_classifier(weights_path=th_weights)
    model = tc.loadmodel()
    model = th2tf(model)
    model.save_weights(tf_weights)
  else:
    print("0-tf2th, 1-th2tf")
    return
if __name__ == '__main__':
  if len(sys.argv) < 4:
    print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
    sys.exit(0)
  tf_weights = sys.argv[1]
  th_weights = sys.argv[2]
  m = int(sys.argv[3])
  conv_layer_converted(tf_weights, th_weights, m)

补充知识:keras学习之修改底层为TensorFlow还是theano

我们知道,keras的底层是TensorFlow或者theano

要知道我们是用的哪个为底层,只需要import keras即可显示

修改方法:

打开

keras实现theano和tensorflow训练的模型相互转换

修改

keras实现theano和tensorflow训练的模型相互转换

以上这篇keras实现theano和tensorflow训练的模型相互转换就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python深入学习之装饰器
Aug 31 Python
在Python中使用模块的教程
Apr 27 Python
Python实现简单的多任务mysql转xml的方法
Feb 08 Python
Python爬虫通过替换http request header来欺骗浏览器实现登录功能
Jan 07 Python
Pandas 对Dataframe结构排序的实现方法
Apr 10 Python
django表单实现下拉框的示例讲解
May 29 Python
PyQt4 treewidget 选择改变颜色,并设置可编辑的方法
Jun 17 Python
pytorch 改变tensor尺寸的实现
Jan 03 Python
python+adb+monkey实现Rom稳定性测试详解
Apr 23 Python
Python学习工具jupyter notebook安装及用法解析
Oct 23 Python
用Python实现童年贪吃蛇小游戏功能的实例代码
Dec 07 Python
pandas 数据类型转换的实现
Dec 29 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 #Python
python中怎么表示空值
Jun 19 #Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 #Python
使用OpenCV对车道进行实时检测的实现示例代码
Jun 19 #Python
为什么python比较流行
Jun 19 #Python
查看keras的默认backend实现方式
Jun 19 #Python
Python图像阈值化处理及算法比对实例解析
Jun 19 #Python
You might like
thinkphp5.1框架模板布局与模板继承用法分析
2019/07/19 PHP
QQ登录简单实现代码
2021/03/09 Javascript
Visual Studio中的jQuery智能提示设置方法
2010/03/27 Javascript
JS字符串函数扩展代码
2011/09/13 Javascript
javascript针对DOM的应用分析(四)
2012/04/15 Javascript
javascript 使用 NodeList需要注意的问题
2013/03/04 Javascript
js异步加载的三种解决方案
2013/03/04 Javascript
JS特权方法定义作用以及与公有方法的区别
2013/03/18 Javascript
jquery 实现input输入什么div图层显示什么
2014/06/15 Javascript
探讨js字符串数组拼接的性能问题
2014/10/11 Javascript
Javascript基础知识盲点总结之函数
2016/05/15 Javascript
JS中的hasOwnProperty()和isPrototypeOf()属性实例详解
2016/08/11 Javascript
微信小程序登录换取token的教程
2018/05/31 Javascript
Vue使用NPM方式搭建项目
2018/10/25 Javascript
微信小程序授权登录解决方案的代码实例(含未通过授权解决方案)
2019/05/10 Javascript
vue学习笔记之给组件绑定原生事件操作示例
2020/02/27 Javascript
JavaScript中ES6规范中let和const的用法和区别
2020/08/06 Javascript
python函数返回多个值的示例方法
2013/12/04 Python
Python使用微信SDK实现的微信支付功能示例
2017/06/30 Python
Django 忘记管理员或忘记管理员密码 重设登录密码的方法
2018/05/30 Python
Python第三方库face_recognition在windows上的安装过程
2019/05/03 Python
python爬虫 猫眼电影和电影天堂数据csv和mysql存储过程解析
2019/09/05 Python
Python socket模块ftp传输文件过程解析
2019/11/05 Python
使用python批量转换文件编码为UTF-8的实现
2020/04/03 Python
Scrapy基于scrapy_redis实现分布式爬虫部署的示例
2020/09/29 Python
Python爬虫开发与项目实战
2020/12/16 Python
css3边框_动力节点Java学院整理
2017/07/11 HTML / CSS
Ellos丹麦:时尚和服装在线
2016/09/19 全球购物
英国奢华护肤、美容和Spa品牌:Temple Spa
2019/11/02 全球购物
美国折扣地毯销售网站:Rugs.com
2020/03/27 全球购物
JMS中Topic和Queue有什么区别
2013/05/15 面试题
Java平台和其他软件平台有什么不同
2015/06/05 面试题
应聘医药销售自荐书范文
2014/02/08 职场文书
学校办公室主任岗位职责
2015/04/01 职场文书
国王的演讲观后感
2015/06/03 职场文书
mysql sum(if())和count(if())的用法说明
2022/01/18 MySQL