keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python迭代器与生成器详解
Mar 10 Python
python中json格式数据输出的简单实现方法
Oct 31 Python
详解python中的 is 操作符
Dec 26 Python
分享Pycharm中一些不为人知的技巧
Apr 03 Python
关于Tensorflow中的tf.train.batch函数的使用
Apr 24 Python
python定时关机小脚本
Jun 20 Python
python logging重复记录日志问题的解决方法
Jul 12 Python
python PrettyTable模块的安装与简单应用
Jan 11 Python
python实现简单俄罗斯方块
Mar 13 Python
python简单的三元一次方程求解实例
Apr 02 Python
Python-openpyxl表格读取写入的案例详解
Nov 02 Python
python中K-means算法基础知识点
Jan 25 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
php日历[测试通过]
2008/03/27 PHP
PHP使用CURL实现对带有验证码的网站进行模拟登录的方法
2014/07/23 PHP
php多线程实现方法及用法实例详解
2015/10/26 PHP
将CKfinder整合进CKEditor3.0的新方法
2010/01/10 Javascript
jQuery 遍历json数组的实现代码
2020/09/22 Javascript
Jquery知识点二 jquery下对数组的操作
2011/01/15 Javascript
javascript使用activex控件的代码
2011/01/27 Javascript
jQuery EasyUI API 中文文档 - NumberSpinner数值微调器使用介绍
2011/10/21 Javascript
PHPExcel中的一些常用方法汇总
2015/01/23 Javascript
一系列Bootstrap导航条使用方法分享
2016/04/29 Javascript
javascript中sort排序实例详解
2016/07/24 Javascript
浅谈EasyUi ComBotree树修改 父节点选择的问题
2016/11/07 Javascript
DropDownList控件绑定数据源的三种方法
2016/12/24 Javascript
Bootstrap中datetimepicker使用小结
2016/12/28 Javascript
微信小程序教程系列之视图层的条件渲染(10)
2017/04/19 Javascript
vue router下的html5 history在iis服务器上的设置方法
2017/10/18 Javascript
详解如何使用webpack在vue项目中写jsx语法
2017/11/08 Javascript
JavaScript变量Dom对象的所有属性
2020/04/30 Javascript
JS highcharts实现动态曲线代码示例
2020/10/16 Javascript
编写Python脚本批量下载DesktopNexus壁纸的教程
2015/05/06 Python
Python random模块用法解析及简单示例
2017/12/18 Python
python复制文件到指定目录的实例
2018/04/27 Python
Python多线程同步---文件读写控制方法
2019/02/12 Python
解决Keras 中加入lambda层无法正常载入模型问题
2020/06/16 Python
python简单实现插入排序实例代码
2020/12/16 Python
Python实现钉钉/企业微信自动打卡的示例代码
2021/02/02 Python
CHARLES & KEITH英国官网:新加坡时尚品牌
2018/07/04 全球购物
For Art’s Sake官网:手工制作的奢华眼镜
2018/12/15 全球购物
英国时尚首饰品牌:Missoma
2020/06/29 全球购物
房地产促销活动方案
2014/03/01 职场文书
《李时珍夜宿古寺》教学反思
2014/04/09 职场文书
实体类或对象序列化时,忽略为空属性的操作
2021/06/30 Java/Android
node.js使用express-fileupload中间件实现文件上传
2021/07/16 Javascript
php修改word的实例方法
2021/11/17 PHP
python中数组和列表的简单实例
2022/03/25 Python
讲解MySQL增删改操作
2022/05/06 MySQL