解决Keras 自定义层时遇到版本的问题


Posted in Python onJune 16, 2020

在2.2.0版本前,

from keras import backend as K
from keras.engine.topology import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # 为该层创建一个可训练的权重
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # 一定要在最后调用它
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

2.2.0 版本时:

from keras import backend as K
from keras.layers import Layer
 
class MyLayer(Layer):
 
  def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(MyLayer, self).__init__(**kwargs)
 
  def build(self, input_shape):
    # Create a trainable weight variable for this layer.
    self.kernel = self.add_weight(name='kernel', 
                   shape=(input_shape[1], self.output_dim),
                   initializer='uniform',
                   trainable=True)
    super(MyLayer, self).build(input_shape) # Be sure to call this at the end
 
  def call(self, x):
    return K.dot(x, self.kernel)
 
  def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_dim)

如果你遇到:

<module> from keras.engine.base_layer import InputSpec ModuleNotFoundError: No module named 'keras.engine.base_layer'

不妨试试另一种引入!

补充知识:Keras自定义损失函数在场景分类的使用

在做图像场景分类的过程中,需要自定义损失函数,遇到很多坑。Keras自带的损失函数都在losses.py文件中。(以下默认为分类处理)

#losses.py
#y_true是分类的标签,y_pred是分类中预测值(这里指,模型最后一层为softmax层,输出的是每个类别的预测值)
def mean_squared_error(y_true, y_pred):
  return K.mean(K.square(y_pred - y_true), axis=-1)
def mean_absolute_error(y_true, y_pred):
  return K.mean(K.abs(y_pred - y_true), axis=-1)
def mean_absolute_percentage_error(y_true, y_pred):
  diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),K.epsilon(),None))
  return 100. * K.mean(diff, axis=-1)
def mean_squared_logarithmic_error(y_true, y_pred):
  first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
  second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
  return K.mean(K.square(first_log - second_log), axis=-1)
def squared_hinge(y_true, y_pred):
  return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)

这里面简单的来说,y_true就是训练数据的标签,y_pred就是模型训练时经过softmax层的预测值。经过计算,得出损失值。

那么我们要新建损失函数totoal_loss,就要在本文件下,进行新建。

def get_loss(labels,features, alpha,lambda_c,lambda_g,num_classes):
  #由于涉及研究内容,详细代码不做公开
  return loss
#total_loss(y_true,y_pred),y_true代表标签(类别),y_pred代表模型的输出
#( 如果是模型中间层输出,即代表特征,如果模型输出是经过softmax就是代表分类预测值)
#其他有需要的参数也可以写在里面
def total_loss(y_true,y_pred):
    git_loss=get_loss(y_true,y_pred,alpha=0.5,lambda_c=0.001,lambda_g=0.001,num_classes=45)
    return git_loss

自定义损失函数写好之后,可以进行使用了。这里,我使用交叉熵损失函数和自定义损失函数一起使用。

#这里使用vgg16模型
model = VGG16(input_tensor=image_input, include_top=True,weights='imagenet')
model.summary()
#fc2层输出为特征
last_layer = model.get_layer('fc2').output
#获取特征
feature = last_layer
#softmax层输出为各类的预测值
out = Dense(num_classes,activation = 'softmax',name='predictions')(last_layer)
#该模型有一个输入image_input,两个输出out,feature
custom_vgg_model = Model(inputs = image_input, outputs = [feature,out])
custom_vgg_model.summary()
#优化器,梯度下降
sgd = optimizers.SGD(lr=learn_Rate,decay=decay_Rate,momentum=0.9,nesterov=True)
#这里面,刚才有两个输出,这里面使用两个损失函数,total_loss对应的是fc2层输出的特征
#categorical_crossentropy对应softmax层的损失函数
#loss_weights两个损失函数的权重
custom_vgg_model.compile(loss={'fc2': 'total_loss','predictions': "categorical_crossentropy"},
             loss_weights={'fc2': 1, 'predictions':1},optimizer= sgd,
                   metrics={'predictions': 'accuracy'})
#这里使用dummy1,dummy2做演示,为0
dummy1 = np.zeros((y_train.shape[0],4096))
dummy2 = np.zeros((y_test.shape[0],4096))
#模型的输入输出必须和model.fit()中x,y两个参数维度相同
#dummy1的维度和fc2层输出的feature维度相同,y_train和softmax层输出的预测值维度相同
#validation_data验证数据集也是如此,需要和输出层的维度相同
hist = custom_vgg_model.fit(x = X_train,y = {'fc2':dummy1,'predictions':y_train},batch_size=batch_Sizes,
                epochs=epoch_Times, verbose=1,validation_data=(X_test, {'fc2':dummy2,'predictions':y_test}))

写到这里差不多就可以了,不够详细,以后再做补充。

以上这篇解决Keras 自定义层时遇到版本的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3基础之条件与循环控制实例解析
Aug 13 Python
python使用PythonMagick将jpg图片转换成ico图片的方法
Mar 26 Python
python统计文本文件内单词数量的方法
May 30 Python
Python中的探索性数据分析(功能式)
Dec 22 Python
Python实现替换文件中指定内容的方法
Mar 19 Python
python实现多层感知器
Jan 18 Python
python flask框架实现重定向功能示例
Jul 02 Python
python实现大文件分割与合并
Jul 22 Python
Python 函数用法简单示例【定义、参数、返回值、函数嵌套】
Sep 20 Python
基于Python新建用户并产生随机密码过程解析
Oct 08 Python
如何在Python对Excel进行读取
Jun 04 Python
Python 排序最长英文单词链(列表中前一个单词末字母是下一个单词的首字母)
Dec 14 Python
Keras实现支持masking的Flatten层代码
Jun 16 #Python
Keras自定义实现带masking的meanpooling层方式
Jun 16 #Python
浅谈keras 的抽象后端(from keras import backend as K)
Jun 16 #Python
记录模型训练时loss值的变化情况
Jun 16 #Python
python实现批量转换图片为黑白
Jun 16 #Python
在keras中实现查看其训练loss值
Jun 16 #Python
安装python3.7编译器后如何正确安装opnecv的方法详解
Jun 16 #Python
You might like
PHP扩展模块Pecl、Pear以及Perl的区别
2014/04/09 PHP
php获取指定日期之间的各个周和月的起止时间
2014/11/24 PHP
PHP不使用递归的无限级分类简单实例
2016/11/05 PHP
详解yii2实现分库分表的方案与思路
2017/02/03 PHP
详解php用static方法的原因
2018/09/12 PHP
laravel 自定义常量的两种方案
2019/10/14 PHP
Extjs 4.x 得到form CheckBox 复选框的值
2014/05/04 Javascript
Node.js中HTTP模块与事件模块详解
2014/11/14 Javascript
AngularJs中route的使用方法和配置
2016/02/04 Javascript
JavaScript入门教程之引用类型
2016/05/04 Javascript
JS针对浏览器窗口关闭事件的监听方法集锦
2016/06/24 Javascript
angularjs ocLazyLoad分步加载js文件实例
2017/01/17 Javascript
bootstrap table实现单击单元格可编辑功能
2017/03/28 Javascript
详解webpack + react + react-router 如何实现懒加载
2017/11/20 Javascript
jQuery实现动态显示select下拉列表数据的方法
2018/02/05 jQuery
微信小程序如何实现在线客服功能
2019/10/16 Javascript
Django集成百度富文本编辑器uEditor攻略
2014/07/04 Python
Python语言的12个基础知识点小结
2014/07/10 Python
Python内置的HTTP协议服务器SimpleHTTPServer使用指南
2016/03/30 Python
python使用psutil模块获取系统状态
2016/08/27 Python
python http接口自动化脚本详解
2018/01/02 Python
Django学习教程之静态文件的调用详解
2018/05/08 Python
python实现飞机大战微信小游戏
2020/03/21 Python
50行Python代码获取高考志愿信息的实现方法
2019/07/23 Python
Python类反射机制使用实例解析
2019/12/30 Python
什么是python的函数体
2020/06/19 Python
python 多进程和协程配合使用写入数据
2020/10/30 Python
H5 canvas实现贪吃蛇小游戏
2017/07/28 HTML / CSS
幼儿园教师培训制度
2014/01/16 职场文书
《秋游》教学反思
2014/04/24 职场文书
村干部群众路线教育活动对照检查材料
2014/10/01 职场文书
部门群众路线教育实践活动对照检查材料思想汇报
2014/10/07 职场文书
施工安全责任协议书
2016/03/23 职场文书
2016年读书月活动总结范文
2016/04/06 职场文书
python中__slots__节约内存的具体做法
2021/07/04 Python
mysql使用 not int 子查询隐含陷阱
2022/04/12 MySQL