keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python对指定目录下文件进行批量重命名的方法
Apr 18 Python
Django的数据模型访问多对多键值的方法
Jul 21 Python
使用Python实现BT种子和磁力链接的相互转换
Nov 09 Python
python3.x上post发送json数据
Mar 04 Python
python如何定义带参数的装饰器
Mar 20 Python
使用python对文件中的单词进行提取的方法示例
Dec 21 Python
Python小进度条显示代码
Mar 05 Python
python3实现表白神器
Apr 09 Python
PyQT实现菜单中的复制,全选和清空的功能的方法
Jun 17 Python
Python 共享变量加锁、释放详解
Aug 28 Python
Python动态导入模块:__import__、importlib、动态导入的使用场景实例分析
Mar 30 Python
python神经网络编程之手写数字识别
May 08 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
ThinkPHP入口文件设置及相关注意事项分析
2014/12/05 PHP
CodeIgniter框架常见用法工作总结
2017/03/16 PHP
laravel框架上传图片实现实时预览功能
2019/10/14 PHP
PHP+MySql实现一个简单的留言板
2020/07/19 PHP
javascript 兼容FF的onmouseenter和onmouseleave的代码
2008/07/19 Javascript
jQuery timers计时器简单应用说明
2010/10/28 Javascript
Struts2的s:radio标签使用及用jquery添加change事件
2013/04/08 Javascript
node.js实现BigPipe详解
2014/12/05 Javascript
JS实现FLASH幻灯片图片切换效果的方法
2015/03/04 Javascript
JS组件Bootstrap导航条使用方法详解
2016/04/29 Javascript
文本框只能输入数字的js代码(含小数点)
2016/07/10 Javascript
Vue-resource实现ajax请求和跨域请求示例
2017/02/23 Javascript
基于vue循环列表时点击跳转页面的方法
2018/08/31 Javascript
模块化react-router配置方法详解
2019/06/03 Javascript
使用layui监听器监听select下拉框,事件绑定不成功的解决方法
2019/09/28 Javascript
vue搜索页开发实例代码详解(热门搜索,历史搜索,淘宝接口演示)
2020/04/11 Javascript
webpack安装配置与常见使用过程详解(结合vue)
2020/06/01 Javascript
Python如何基于selenium实现自动登录博客园
2019/12/16 Python
NumPy统计函数的实现方法
2020/01/21 Python
python实现飞机大战游戏(pygame版)
2020/10/26 Python
selenium+超级鹰实现模拟登录12306
2021/01/24 Python
The Hut德国站点:时装、家居用品、美容等
2016/09/23 全球购物
Schutz鞋官方网站:Schutz Shoes
2017/12/13 全球购物
Kangol帽子官网:坎戈尔袋鼠
2018/09/26 全球购物
西班牙土拨鼠床垫公司,感觉在云端:Marmota
2019/03/18 全球购物
学习自我鉴定
2014/02/01 职场文书
酒店营销策划方案
2014/02/07 职场文书
2014年大学生党课心得体会范文
2014/03/29 职场文书
六一亲子活动总结
2014/07/01 职场文书
公司自我介绍演讲稿
2014/08/21 职场文书
2014教师党员自我评议(5篇)
2014/09/20 职场文书
学生上课看漫画的检讨书
2014/09/26 职场文书
2015小学音乐教师个人工作总结
2015/07/21 职场文书
CSS3实现的侧滑菜单
2021/04/27 HTML / CSS
Go 在 MongoDB 中常用查询与修改的操作
2021/05/07 Golang
Java实现经典游戏泡泡堂的示例代码
2022/04/04 Java/Android