Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现获取nginx服务器ip及流量统计信息功能示例
May 18 Python
对python捕获ctrl+c手工中断程序的两种方法详解
Dec 26 Python
python实时获取外部程序输出结果的方法
Jan 12 Python
Python3实现的反转单链表算法示例
Mar 08 Python
Python实现html转换为pdf报告(生成pdf报告)功能示例
May 04 Python
tensorflow 实现自定义梯度反向传播代码
Feb 10 Python
python实现滑雪者小游戏
Feb 22 Python
Python编程快速上手——强口令检测算法案例分析
Feb 29 Python
Python 实现使用空值进行赋值 None
Mar 12 Python
parser.add_argument中的action使用
Apr 20 Python
windows下python 3.9 Numpy scipy和matlabplot的安装教程详解
Nov 28 Python
python中altair可视化库实例用法
Jan 26 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
绿山咖啡和蓝山咖啡
2021/03/04 新手入门
PHP中如何判断AJAX提交的数据
2012/02/05 PHP
PHP session会话操作技巧小结
2016/09/27 PHP
php正则修正符用法实例详解
2016/12/29 PHP
laravel框架邮箱认证实现方法详解
2019/11/22 PHP
BOOM vs RR BO5 第二场 2.14
2021/03/10 DOTA
CL vs ForZe BO5 第二场 2.13
2021/03/10 DOTA
js获取input标签的输入值实现代码
2013/08/05 Javascript
用javascript对一个json数组深度赋值示例
2014/07/27 Javascript
AngularJs根据访问的页面动态加载Controller的解决方案
2015/02/04 Javascript
JavaScript的History API使搜索引擎抓取AJAX内容
2015/12/07 Javascript
jquery分隔Url的param方法(推荐)
2016/05/25 Javascript
JavaScript编程中实现对象封装特性的实例讲解
2016/06/24 Javascript
微信小程序 登录实例详解
2017/01/16 Javascript
防止重复发送 Ajax 请求
2017/02/15 Javascript
vue2.0+vuex+localStorage代办事项应用实现详解
2018/05/31 Javascript
VScode格式化ESlint方法(最全最好用方法)
2019/09/10 Javascript
详解element-ui中表单验证的三种方式
2019/09/18 Javascript
HTML+JS实现“代码雨”效果源码(黑客帝国文字下落效果)
2020/03/17 Javascript
原生js实现贪吃蛇游戏
2020/10/26 Javascript
微信小程序学习之自定义滚动弹窗
2020/12/20 Javascript
用python处理图片之打开\显示\保存图像的方法
2018/05/04 Python
Python 删除连续出现的指定字符的实例
2018/06/29 Python
python实现换位加密算法的示例
2018/10/14 Python
Python 循环终止语句的三种方法小结
2019/06/24 Python
python的pygal模块绘制反正切函数图像方法
2019/07/16 Python
Python使用get_text()方法从大段html中提取文本的实例
2019/08/27 Python
基于python实现百度语音识别和图灵对话
2020/11/02 Python
翻译学院毕业生自荐书
2014/02/02 职场文书
考研英语复习计划
2015/01/19 职场文书
五星级酒店前台接待岗位职责
2015/04/02 职场文书
2015年科室工作总结
2015/04/10 职场文书
药品开票员岗位职责
2015/04/15 职场文书
花田少年史观后感
2015/06/16 职场文书
Python3 类型标注支持操作
2021/06/02 Python
JavaScript异步操作中串行和并行
2021/11/20 Javascript