Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python bsddb模块操作Berkeley DB数据库介绍
Apr 08 Python
Python 实现链表实例代码
Apr 07 Python
在Python中增加和插入元素的示例
Nov 01 Python
Python 从一个文件中调用另一个文件的类方法
Jan 10 Python
Python3.5面向对象与继承图文实例详解
Apr 24 Python
python自定义函数实现最大值的输出方法
Jul 09 Python
Pyqt5自适应布局实例
Dec 13 Python
Python求凸包及多边形面积教程
Apr 12 Python
使用python实现下载我们想听的歌曲,速度超快
Jul 09 Python
安装Anaconda3及使用Jupyter的方法
Oct 27 Python
selenium如何定位span元素的实现
Jan 13 Python
Pandas实现DataFrame的简单运算、统计与排序
Mar 31 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
php报表之jpgraph柱状图实例代码
2011/08/22 PHP
zf框架db类的分页示例分享
2014/03/14 PHP
PHP实现创建一个RPC服务操作示例
2020/02/23 PHP
js 模拟气泡屏保效果代码
2010/07/10 Javascript
jquery 选项卡效果 新手代码
2011/07/08 Javascript
jQuery在页面加载时动态修改图片尺寸的方法
2015/03/20 Javascript
AngularJS 所有版本下载地址
2016/09/14 Javascript
JavaScript中常用的验证reg
2016/10/13 Javascript
基于bootstrap的文件上传控件bootstrap fileinput
2016/12/23 Javascript
浅谈javascript的闭包
2017/01/23 Javascript
老生常谈js中0到底是 true 还是 false
2017/03/08 Javascript
JS简单验证上传文件类型的方法
2017/04/17 Javascript
Vue2.0父组件与子组件之间的事件发射与接收实例代码
2017/09/19 Javascript
js实时监控文本框输入字数的实例代码
2018/01/18 Javascript
在AngularJs中设置请求头信息(headers)的方法及不同方法的比较
2018/09/04 Javascript
Vue使用lodop实现打印小结
2019/07/06 Javascript
微信小程序实现左侧滑栏过程解析
2019/08/26 Javascript
js闭包和垃圾回收机制示例详解
2021/03/01 Javascript
[56:13]DOTA2-DPC中国联赛定级赛 LBZS vs Phoenix BO3第一场 1月10日
2021/03/11 DOTA
基python实现多线程网页爬虫
2015/09/06 Python
python+opencv实现的简单人脸识别代码示例
2017/11/14 Python
python中使用iterrows()对dataframe进行遍历的实例
2018/06/09 Python
Python3爬虫学习入门教程
2018/12/11 Python
Django分组聚合查询实例分享
2020/04/29 Python
美国网上订购鲜花:FTD
2016/09/23 全球购物
EGO Shoes美国/加拿大:英国时髦鞋类品牌
2018/08/04 全球购物
马来西亚与新加坡长途巴士售票网站:BusOnlineTicket.com
2018/11/05 全球购物
墨尔本复古时尚品牌:Dangerfield
2018/12/12 全球购物
千禧酒店及度假村官方网站:Millennium Hotels and Resorts
2019/05/10 全球购物
给排水工程师岗位职责
2013/11/21 职场文书
大学生如何写自荐信
2014/01/08 职场文书
国家税务局干部作风整顿整改措施
2014/09/18 职场文书
2014统计局民主生活会对照检查材料思想汇报
2014/10/02 职场文书
Win11 引入 Windows 365 云操作系统,适应疫情期间混合办公模式:启动时直接登录、模
2022/04/06 数码科技
python数字图像处理之图像的批量处理
2022/06/28 Python
win10电脑右下角输入法图标不见了?Win10右下角不显示输入法的解决方法
2022/07/23 数码科技