keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python多线程下的变量问题
Apr 28 Python
Python中字典的基本知识初步介绍
May 21 Python
Python实现控制台进度条功能
Jan 04 Python
深入解析Python中的线程同步方法
Jun 14 Python
深入理解Python中range和xrange的区别
Nov 26 Python
python3利用smtplib通过qq邮箱发送邮件方法示例
Dec 03 Python
Python内建模块struct实例详解
Feb 02 Python
详谈Python中列表list,元祖tuple和numpy中的array区别
Apr 18 Python
python实现简单的单变量线性回归方法
Nov 08 Python
Python3多线程基础知识点
Feb 19 Python
Python 串口通信的实现
Sep 29 Python
Python return语句如何实现结果返回调用
Oct 15 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
DOTA2 探索永无止境 玩家自创强悍插眼攻略
2020/04/20 DOTA
PHP4与PHP5的时间格式问题
2008/02/17 PHP
php 文件缓存函数
2011/10/08 PHP
无法在发生错误时创建会话,请检查 PHP 或网站服务器日志,并正确配置 PHP 安装(win+linux)
2012/05/05 PHP
基于Zend的Config机制的应用分析
2013/05/02 PHP
PHP数组函数知识汇总
2016/05/12 PHP
php慢查询日志和错误日志使用详解
2021/02/27 PHP
CL vs ForZe BO5 第五场 2.13
2021/03/10 DOTA
jquery中获取id值方法小结
2013/09/22 Javascript
50 个 jQuery 插件可将你的网站带到另外一个高度
2016/04/26 Javascript
基于JavaScript实现在新的tab页打开url
2016/08/04 Javascript
jQuery实现可拖拽的许愿墙效果【附demo源码下载】
2016/09/14 Javascript
JavaScript评论点赞功能的实现方法
2017/03/13 Javascript
React学习之事件绑定的几种方法对比
2017/09/24 Javascript
react以create-react-app为基础创建项目
2018/03/14 Javascript
VUE简单的定时器实时刷新的实现方法
2019/01/20 Javascript
Openlayers显示地理位置坐标的方法
2020/09/28 Javascript
[01:38]DOTA2 2015国际邀请赛中国区预选赛 Showopen
2015/06/01 DOTA
python的绘图工具matplotlib使用实例
2014/07/03 Python
python获取一组汉字拼音首字母的方法
2015/07/01 Python
Python网络爬虫与信息提取(实例讲解)
2017/08/29 Python
Python编程求质数实例代码
2018/01/31 Python
python with (as)语句实例详解
2020/02/04 Python
PyTorch中torch.tensor与torch.Tensor的区别详解
2020/05/18 Python
Python 抓取数据存储到Redis中的操作
2020/07/16 Python
canvas实现飞机打怪兽射击小游戏的示例代码
2018/07/09 HTML / CSS
Eastbay官网:美国最大的运动鞋网络零售商
2016/07/27 全球购物
巴黎卡诗加拿大官网:Kérastase加拿大
2018/11/12 全球购物
出纳岗位职责范本
2013/12/01 职场文书
教师自我鉴定范文
2014/03/20 职场文书
班子四风对照检查材料
2014/08/21 职场文书
单位委托书格式范本
2014/09/29 职场文书
撤诉书怎么写
2015/05/19 职场文书
2016先进工作者事迹材料
2016/02/25 职场文书
2019XX公司员工考核管理制度!
2019/08/07 职场文书
MySQL系列之二 多实例配置
2021/07/02 MySQL