Keras之自定义损失(loss)函数用法说明


Posted in Python onJune 10, 2020

在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考Keras官方metrics的定义keras/metrics.py:

"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return K.cast(K.equal(K.argmax(y_true, axis=-1),
       K.argmax(y_pred, axis=-1)),
     K.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if K.ndim(y_true) == K.ndim(y_pred):
  y_true = K.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = K.argmax(y_pred, axis=-1)
 y_pred_labels = K.cast(y_pred_labels, K.floatx())
 return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
     axis=-1)
 
# Aliases
 
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=None):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise ValueError('Could not interpret '
       'metric function identifier:', identifier)

以上这篇Keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
理解python多线程(python多线程简明教程)
Jun 09 Python
如何在Python中编写并发程序
Feb 27 Python
python二分查找算法的递归实现方法
May 12 Python
详解python调度框架APScheduler使用
Mar 28 Python
Django REST为文件属性输出完整URL的方法
Dec 18 Python
Python用csv写入文件_消除空余行的方法
Jul 06 Python
解决Python一行输出不显示的问题
Dec 03 Python
python 导入数据及作图的实现
Dec 03 Python
Python sep参数使用方法详解
Feb 12 Python
Python利用 utf-8-sig 编码格式解决写入 csv 文件乱码问题
Feb 21 Python
python爬取招聘要求等信息实例
Nov 20 Python
Python必备技巧之函数的使用详解
Apr 04 Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 #Python
Python中的__init__作用是什么
Jun 09 #Python
You might like
PHP面向对象的进阶学习(抽像类、接口、final、类常量)
2012/05/07 PHP
php下载文件的代码示例
2012/06/29 PHP
访问编码后的中文URL返回404错误的解决方法
2014/08/20 PHP
用 Composer构建自己的 PHP 框架之基础准备
2014/10/30 PHP
深入探究PHP的多进程编程方法
2015/08/18 PHP
php实现保存周期为1天的购物车类
2017/07/07 PHP
js左右弹性滚动对联广告代码分享
2014/02/19 Javascript
js QQ客服悬浮效果实现代码
2014/12/12 Javascript
javascript实现表格排序 编辑 拖拽 缩放
2015/01/02 Javascript
nodejs开发微博实例
2015/03/25 NodeJs
Javascript中arguments和arguments.callee的区别浅析
2015/04/24 Javascript
Javascript实现div层渐隐效果的方法
2015/05/30 Javascript
Bootstrap打造一个左侧折叠菜单的系统模板(一)
2016/05/17 Javascript
利用JS实现简单的瀑布流加载图片效果
2017/04/22 Javascript
Ionic项目中Native Camera的使用方法
2017/06/07 Javascript
JS实现给json数组动态赋值的方法示例
2020/03/19 Javascript
vue环境搭建简单教程
2017/11/07 Javascript
浅谈关于JS下大批量异步任务按顺序执行解决方案一点思考
2019/01/08 Javascript
浅谈Javascript中的对象和继承
2019/04/19 Javascript
Vue基于iview实现登录密码的显示与隐藏功能
2020/03/06 Javascript
JavaScript Image对象实现原理实例解析
2020/08/26 Javascript
vue cli 3.0通用打包配置代码,不分一二级目录
2020/09/02 Javascript
[00:32]DOTA2上海特级锦标赛 COL战队宣传片
2016/03/04 DOTA
[01:09]模型精美,特效酷炫!TI9不朽宝藏Ⅰ鉴赏
2019/05/10 DOTA
python类定义的讲解
2013/11/01 Python
python利用urllib和urllib2访问http的GET/POST详解
2017/09/27 Python
python3实现163邮箱SMTP发送邮件
2018/05/22 Python
python从ftp获取文件并下载到本地
2020/12/05 Python
财务会计实训报告
2014/11/05 职场文书
经验交流材料格式
2014/12/30 职场文书
2015年餐厅服务员工作总结
2015/04/23 职场文书
个人收入证明格式
2015/06/24 职场文书
《打电话》教学反思
2016/02/22 职场文书
深入理解CSS 中 transform matrix矩阵变换问题
2021/08/30 HTML / CSS
Python可视化学习之matplotlib内置单颜色
2022/02/24 Python
科学家研发出新型速效酶,可在 24 小时内降解塑料制品
2022/04/29 数码科技