Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则表达式的使用
Jun 12 Python
Python学习小技巧总结
Jun 10 Python
python实现电脑自动关机
Jun 20 Python
django的ORM操作 增加和查询
Jul 26 Python
selenium+python实现自动登陆QQ邮箱并发送邮件功能
Dec 13 Python
python实现查找所有程序的安装信息
Feb 18 Python
MxNet预训练模型到Pytorch模型的转换方式
May 25 Python
Python代码注释规范代码实例解析
Aug 14 Python
python类共享变量操作
Sep 03 Python
windows系统Tensorflow2.x简单安装记录(图文)
Jan 18 Python
聊聊python在linux下与windows下导入模块的区别说明
Mar 03 Python
Python借助with语句实现代码段只执行有限次
Mar 23 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
php模板中出现空行解决方法
2011/03/08 PHP
php笔记之:AOP的应用
2013/04/24 PHP
基于curl数据采集之正则处理函数get_matches的使用
2013/04/28 PHP
PHP时间戳 strtotime()使用方法和技巧
2013/10/29 PHP
php中hashtable实现示例分享
2014/02/13 PHP
php实现的漂亮分页方法
2014/04/17 PHP
php设置页面超时时间解决方法
2015/09/22 PHP
CodeIgniter配置之autoload.php自动加载用法分析
2016/01/20 PHP
php中namespace use用法实例分析
2016/01/22 PHP
css3实现背景模糊的三种方式
2021/03/09 HTML / CSS
javascript 延迟加载技术(lazyload)简单实现
2011/01/17 Javascript
javascript 学习笔记(八)javascript对象
2011/04/12 Javascript
Jquery的hide及toggle方法让超链接慢慢消失
2013/09/06 Javascript
jquery easyui combobox模糊过滤(示例代码)
2013/11/30 Javascript
Nodejs的express使用教程
2015/11/23 NodeJs
完美解决IE不支持Data.parse()的问题
2016/11/24 Javascript
canvas绘制的直线动画
2017/01/23 Javascript
vue.js树形组件之删除双击增加分支实例代码
2017/02/28 Javascript
原生js实现吸顶效果
2017/03/13 Javascript
jquery实现图片放大点击切换
2017/06/06 jQuery
Vue-cli@3.0 插件系统简析
2018/09/05 Javascript
微信小程序实现通过js操作wxml的wxss属性示例
2018/12/06 Javascript
node中IO以及定时器优先级详解
2019/05/10 Javascript
云服务器部署Node.js项目的方法步骤(小白系列)
2020/03/23 Javascript
使用 Opentype.js 生成字体子集的实例代码详解
2020/05/25 Javascript
Python基于贪心算法解决背包问题示例
2017/11/27 Python
使用Python批量修改文件名的代码实例
2019/01/24 Python
详解利用css3的var()实现运行时改变scss的变量值
2021/03/02 HTML / CSS
记一次高分屏下canvas模糊问题
2020/02/17 HTML / CSS
工程招投标邀请书
2014/01/30 职场文书
公司承诺书范文
2014/05/19 职场文书
幼儿园迎国庆65周年活动策划方案
2014/09/16 职场文书
领导干部作风建设总结
2014/10/23 职场文书
布达拉宫的导游词
2015/02/02 职场文书
解决Jupyter-notebook不弹出默认浏览器的问题
2021/03/30 Python
基于PyTorch实现一个简单的CNN图像分类器
2021/05/29 Python