解决Keras 中加入lambda层无法正常载入模型问题


Posted in Python onJune 16, 2020

刚刚解决了这个问题,现在记录下来

问题描述

当使用lambda层加入自定义的函数后,训练没有bug,载入保存模型则显示Nonetype has no attribute 'get'

问题解决方法:

这个问题是由于缺少config信息导致的。lambda层在载入的时候需要一个函数,当使用自定义函数时,模型无法找到这个函数,也就构建不了。

m = load_model(path,custom_objects={"reduce_mean":self.reduce_mean,"slice":self.slice})

其中,reduce_mean 和slice定义如下

def slice(self,x, turn):
    """ Define a tensor slice function
    """
    return x[:, turn, :, :]
  def reduce_mean(self, X):
    return K.mean(X, axis=-1)

补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案

一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。

保存时会报

TypeError: can't pickle _thread.RLock objects

二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。

from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io

def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
  if not os.path.exists(output_dir):
    os.mkdir(output_dir)
  h5_model = build_model()
  h5_model.load_weights(h5_weight_path)
  out_nodes = []
  for i in range(len(h5_model.outputs)):
    out_nodes.append(out_prefix + str(i + 1))
    tf.identity(h5_model.output[i], out_prefix + str(i + 1))
  model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
  sess = K.get_session()
  init_graph = sess.graph.as_graph_def()
  main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
  graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
  if log_tensorboard:
    from tensorflow.python.tools import import_pb_to_tensorboard
    import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)

def build_model():
  inputs = Input(shape=(784,), name='input_img')
  x = Dense(64, activation='relu')(inputs)
  x = Dense(64, activation='relu')(x)
  y = Dense(10, activation='softmax')(x)
  h5_model = Model(inputs=inputs, outputs=y)
  return h5_model

if __name__ == '__main__':
  if len(sys.argv) == 3:
    # usage: python3 h5_to_pb.py h5_weight_path output_dir
    h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])

以上这篇解决Keras 中加入lambda层无法正常载入模型问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python urls.py的三种配置写法实例详解
Apr 28 Python
详谈python http长连接客户端
Jun 12 Python
Python 获取当前所在目录的方法详解
Aug 02 Python
Python中 传递值 和 传递引用 的区别解析
Feb 22 Python
对pandas replace函数的使用方法小结
May 18 Python
python基础学习之如何对元组各个元素进行命名详解
Jul 12 Python
numpy数组广播的机制
Jul 12 Python
Python无头爬虫下载文件的实现
Apr 02 Python
python 模拟在天空中放风筝的示例代码
Apr 21 Python
python实战之一步一步教你绘制小猪佩奇
Apr 22 Python
Python基础之操作MySQL数据库
May 06 Python
教你如何用Python实现人脸识别(含源代码)
Jun 23 Python
结束运行python的方法
Jun 16 #Python
深入理解Python 多线程
Jun 16 #Python
keras.layer.input()用法说明
Jun 16 #Python
python适合做数据挖掘吗
Jun 16 #Python
Python+PyQt5+MySQL实现天气管理系统
Jun 16 #Python
Python实现SMTP邮件发送
Jun 16 #Python
python语言中有算法吗
Jun 16 #Python
You might like
用PHP实现将GB编码转换为UTF8
2006/11/25 PHP
站长助手-网站web在线管理程序 v1.0 下载
2007/05/12 PHP
探讨:php中在foreach中使用foreach ($arr as &$value) 这种类型的解释
2013/06/24 PHP
php创建类并调用的实例方法
2019/09/25 PHP
goto语法在PHP中的使用教程
2020/09/17 PHP
10个基于Jquery的幻灯片插件教程
2010/10/29 Javascript
JavaScript 模式之工厂模式(Factory)应用介绍
2012/11/15 Javascript
js动态创建上传表单通过iframe模拟Ajax实现无刷新
2014/02/20 Javascript
AngularJS基础学习笔记之控制器
2015/05/10 Javascript
Javascript removeChild()删除节点及删除子节点的方法
2015/12/27 Javascript
jQuery Dialog对话框事件用法实例分析
2016/05/10 Javascript
JS实现最简单的冒泡排序算法
2017/02/15 Javascript
关于定制FileField中的上传文件名称问题
2017/08/22 Javascript
vue2.0之多页面的开发的示例
2018/01/30 Javascript
jquery 动态遍历select 赋值的实例
2018/09/12 jQuery
深入理解JavaScript 中的执行上下文和执行栈
2018/10/23 Javascript
深入理解js A*寻路算法原理与具体实现过程
2018/12/13 Javascript
vue实现鼠标移入移出事件代码实例
2019/03/27 Javascript
JS前端知识点 运算符优先级,URL编码与解码,String,Math,arguments操作整理总结
2019/06/27 Javascript
详解ES6新增字符串扩张方法includes()、startsWith()、endsWith()
2020/05/12 Javascript
Pyspider中给爬虫伪造随机请求头的实例
2018/05/07 Python
Django unittest 设置跳过某些case的方法
2018/12/26 Python
Python 异步协程函数原理及实例详解
2019/11/13 Python
python飞机大战 pygame游戏创建快速入门详解
2019/12/17 Python
Python3操作读写CSV文件使用包过程解析
2020/04/10 Python
如何Tkinter模块编写Python图形界面
2020/10/14 Python
matplotlib自定义鼠标光标坐标格式的实现
2021/01/08 Python
CSS实现限制字数功能当对象内文本溢出时显示省略标记
2014/08/20 HTML / CSS
迪卡侬比利时官网:Decathlon比利时
2019/12/28 全球购物
比较基础的php面试题及答案-填空题
2014/04/26 面试题
"序列点" 是什么
2016/07/29 面试题
募捐倡议书怎么写
2014/05/14 职场文书
大学生入党积极分子党校学习思想汇报
2014/10/25 职场文书
物业项目经理岗位职责
2015/04/01 职场文书
仓库统计员岗位职责
2015/04/14 职场文书
红色经典电影观后感
2015/06/18 职场文书