解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈Python单向链表的实现
Dec 24 Python
深入学习Python中的装饰器使用
Jun 20 Python
python实现键盘控制鼠标移动
Nov 27 Python
对python3 中方法各种参数和返回值详解
Dec 15 Python
在numpy矩阵中令小于0的元素改为0的实例
Jan 26 Python
Python 列表去重去除空字符的例子
Jul 20 Python
Python解决pip install时出现的Could not fetch URL问题
Aug 01 Python
Django框架model模型对象验证实现方法分析
Oct 02 Python
Python yield的用法实例分析
Mar 06 Python
Python GUI编程学习笔记之tkinter中messagebox、filedialog控件用法详解
Mar 30 Python
django实现模型字段动态choice的操作
Apr 01 Python
python库Tsmoothie模块数据平滑化异常点抓取
Jun 10 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
PHP图片处理之图片旋转和图片翻转实例
2014/11/19 PHP
PHP Trait代码复用类与多继承实现方法详解
2019/06/17 PHP
php写入mysql中文乱码的实例解决方法
2019/09/17 PHP
javascript 动态调整图片尺寸实现代码
2009/12/28 Javascript
jQuery UI Dialog 创建友好的弹出对话框实现代码
2012/04/12 Javascript
javascript 禁用IE工具栏,导航栏等等实现代码
2013/04/01 Javascript
JS通过相同的name进行表格求和代码
2013/08/18 Javascript
基于jquery插件制作左右按钮与标题文字图片切换效果
2013/11/07 Javascript
JQuery选择器、过滤器大整理
2015/05/26 Javascript
JS实现文档加载完成后执行代码
2015/07/09 Javascript
jQuery表格插件datatables用法详解
2020/11/23 Javascript
js获取元素的标签名实现方法
2016/10/08 Javascript
关于angularJs指令的Scope(作用域)介绍
2016/10/25 Javascript
进阶之初探nodeJS
2017/01/24 NodeJs
jQuery插件HighCharts绘制2D圆环图效果示例【附demo源码下载】
2017/03/09 Javascript
基于JavaScript实现移动端无限加载分页
2017/03/27 Javascript
vue使用中的内存泄漏【推荐】
2018/07/10 Javascript
jquery ajax加载数据前台渲染方式 不用for遍历的方法
2018/08/09 jQuery
Vue自定义弹窗指令的实现代码
2018/08/13 Javascript
详解vue组件中使用路由方法
2019/02/12 Javascript
使用typescript改造koa开发框架的实现
2020/02/04 Javascript
jQuery 实现扁平式小清新导航
2020/07/07 jQuery
vue+elementUI(el-upload)图片压缩,默认同比例压缩操作
2020/08/10 Javascript
vue 组件基础知识总结
2021/01/26 Vue.js
Python3 把一个列表按指定数目分成多个列表的方式
2019/12/25 Python
python 实现有道翻译功能
2021/02/26 Python
serialVersionUID具有什么样的特征
2014/02/20 面试题
小学英语教学反思
2014/01/30 职场文书
超市促销活动方案
2014/03/05 职场文书
住宅使用说明书
2014/05/09 职场文书
文体活动总结
2015/02/04 职场文书
消防宣传语大全
2015/07/13 职场文书
2015年社区国庆节活动总结
2015/07/30 职场文书
学习习近平主席讲话心得体会
2016/01/20 职场文书
vue2实现provide inject传递响应式
2021/05/21 Vue.js
MySQL的安装与配置详细教程
2021/06/26 MySQL