Keras中的多分类损失函数用法categorical_crossentropy


Posted in Python onJune 11, 2020

from keras.utils.np_utils import to_categorical

注意:当使用categorical_crossentropy损失函数时,你的标签应为多类模式,例如如果你有10个类别,每一个样本的标签应该是一个10维的向量,该向量在对应有值的索引位置为1其余为0。

可以使用这个方法进行转换:

from keras.utils.np_utils import to_categorical
categorical_labels = to_categorical(int_labels, num_classes=None)

以mnist数据集为例:

from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

...
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X_train, y_train, epochs=100, batch_size=1, verbose=2)

补充知识:Keras中损失函数binary_crossentropy和categorical_crossentropy产生不同结果的分析

问题

在使用keras做对心电信号分类的项目中发现一个问题,这个问题起源于我的一个使用错误:

binary_crossentropy 二进制交叉熵用于二分类问题中,categorical_crossentropy分类交叉熵适用于多分类问题中,我的心电分类是一个多分类问题,但是我起初使用了二进制交叉熵,代码如下所示:

sgd = SGD(lr=0.003, decay=0, momentum=0.7, nesterov=False)
model.compile(loss='categorical_crossentropy',
  optimizer='sgd',metrics=['accuracy'])
model.fit(X_train, Y_train, validation_data=(X_test,Y_test),batch_size=16, epochs=20)
score = model.evaluate(X_test, Y_test, batch_size=16)

注意:我的CNN网络模型在最后输入层正确使用了应该用于多分类问题的softmax激活函数

后来我在另一个残差网络模型中对同类数据进行相同的分类问题中,正确使用了分类交叉熵,令人奇怪的是残差模型的效果远弱于普通卷积神经网络,这一点是不符合常理的,经过多次修改分析终于发现可能是损失函数的问题,因此我使用二进制交叉熵在残差网络中,终于取得了优于普通卷积神经网络的效果。

因此可以断定问题就出在所使用的损失函数身上

原理

本人也只是个只会使用框架的调参侠,对于一些原理也是一知半解,经过了学习才大致明白,将一些原理记录如下:

要搞明白分类熵和二进制交叉熵先要从二者适用的激活函数说起

激活函数

sigmoid, softmax主要用于神经网络输出层的输出。

softmax函数

Keras中的多分类损失函数用法categorical_crossentropy

softmax可以看作是Sigmoid的一般情况,用于多分类问题。

Softmax函数将K维的实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于 (0,1) 之间。常用于多分类问题。

sigmoid函数

Keras中的多分类损失函数用法categorical_crossentropy

Sigmoid 将一个实数映射到 (0,1) 的区间,可以用来做二分类。Sigmoid 在特征相差比较复杂或是相差不是特别大时效果比较好。Sigmoid不适合用在神经网络的中间层,因为对于深层网络,sigmoid 函数反向传播时,很容易就会出现梯度消失的情况(在 sigmoid 接近饱和区时,变换太缓慢,导数趋于 0,这种情况会造成信息丢失),从而无法完成深层网络的训练。所以Sigmoid主要用于对神经网络输出层的激活。

分析

所以说多分类问题是要softmax激活函数配合分类交叉熵函数使用,而二分类问题要使用sigmoid激活函数配合二进制交叉熵函数适用,但是如果在多分类问题中使用了二进制交叉熵函数最后的模型分类效果会虚高,即比模型本身真实的分类效果好。

所以就会出现我遇到的情况,这里引用了论坛一位大佬的样例:

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # WRONG way

model.fit(x_train, y_train,
   batch_size=batch_size,
   epochs=2, # only 2 epochs, for demonstration purposes
   verbose=1,
   validation_data=(x_test, y_test))

# Keras reported accuracy:
score = model.evaluate(x_test, y_test, verbose=0) 
score[1]
# 0.9975801164627075

# Actual accuracy calculated manually:
import numpy as np
y_pred = model.predict(x_test)
acc = sum([np.argmax(y_test[i])==np.argmax(y_pred[i]) for i in range(10000)])/10000
acc
# 0.98780000000000001

score[1]==acc
# False

样例中模型在评估中得到的准确度高于实际测算得到的准确度,网上给出的原因是Keras没有定义一个准确的度量,但有几个不同的,比如binary_accuracy和categorical_accuracy,当你使用binary_crossentropy时keras默认在评估过程中使用了binary_accuracy,但是针对你的分类要求,应当采用的是categorical_accuracy,所以就造成了这个问题(其中的具体原理我也没去看源码详细了解)

解决

所以问题最后的解决方法就是:

对于多分类问题,要么采用

from keras.metrics import categorical_accuracy
model.compile(loss='binary_crossentropy', 
 optimizer='adam', metrics=[categorical_accuracy])

要么采用

model.compile(loss='categorical_crossentropy',
optimizer='adam',metrics=['accuracy'])

以上这篇Keras中的多分类损失函数用法categorical_crossentropy就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python异常学习笔记
Feb 03 Python
python实现查找两个字符串中相同字符并输出的方法
Jul 11 Python
python实现报表自动化详解
Nov 16 Python
使用Python通过win32 COM实现Word文档的写入与保存方法
May 08 Python
Vue的el-scrollbar实现自定义滚动
May 29 Python
Python实现决策树C4.5算法的示例
May 30 Python
Python3用tkinter和PIL实现看图工具
Jun 21 Python
Python拼接微信好友头像大图的实现方法
Aug 01 Python
Python实现正则表达式匹配任意的邮箱方法
Dec 20 Python
Django:使用filter的pk进行多值查询操作
Jul 15 Python
我对PyTorch dataloader里的shuffle=True的理解
May 20 Python
Python+DeOldify实现老照片上色功能
Jun 21 Python
Python 列表中的修改、添加和删除元素的实现
Jun 11 #Python
python中什么是面向对象
Jun 11 #Python
python实现凯撒密码、凯撒加解密算法
Jun 11 #Python
python新手学习可变和不可变对象
Jun 11 #Python
基于Keras 循环训练模型跑数据时内存泄漏的解决方式
Jun 11 #Python
什么是python的id函数
Jun 11 #Python
Keras:Unet网络实现多类语义分割方式
Jun 11 #Python
You might like
PHP查询网站的PR值
2013/10/30 PHP
php数组去除空值函数分享
2015/02/02 PHP
从Ajax到JQuery Ajax学习
2007/02/14 Javascript
使用CSS3的scale实现网页整体缩放
2014/03/18 Javascript
jQuery toggleClass应用实例(附效果图)
2014/04/06 Javascript
JavaScript获取URL汇总
2015/06/08 Javascript
JavaScript保存并运算页面中数字类型变量的写法
2015/07/06 Javascript
使用React实现轮播效果组件示例代码
2016/09/05 Javascript
快速解决js开发下拉框中blur与click冲突
2016/10/10 Javascript
js实现对table的增加行和删除行的操作方法
2016/10/13 Javascript
jquery pagination分页插件使用详解(后台struts2)
2017/01/22 Javascript
d3.js中冷门却实用的内置函数总结
2017/02/04 Javascript
bootstrap table表格客户端分页实例
2017/08/07 Javascript
react高阶组件经典应用之权限控制详解
2017/09/07 Javascript
vue+springmvc导出excel数据的实现代码
2018/06/27 Javascript
vue的style绑定background-image的方式和其他变量数据的区别详解
2018/09/03 Javascript
微信小程序时间控件picker view使用详解
2018/12/28 Javascript
Vue图片浏览组件v-viewer用法分析【支持旋转、缩放、翻转等操作】
2019/11/04 Javascript
[01:39:42]Fnatic vs Mineski 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/18 DOTA
python类型强制转换long to int的代码
2013/02/10 Python
Python实现的数据结构与算法之链表详解
2015/04/22 Python
Python创建二维数组实例(关于list的一个小坑)
2017/11/07 Python
python中format()函数的简单使用教程
2018/03/14 Python
Python logging模块用法示例
2018/08/28 Python
PyQt5 实现字体大小自适应分辨率的方法
2019/06/18 Python
Python 实现二叉查找树的示例代码
2020/12/21 Python
html5自动播放mov格式视频的实例代码
2020/01/14 HTML / CSS
elf彩妆英国官网:e.l.f. Cosmetics英国(美国平价彩妆品牌)
2017/11/02 全球购物
亿企通软件测试面试题
2012/04/10 面试题
计算机网络专业个人的自我评价
2013/10/17 职场文书
写自荐信的注意事项
2014/03/09 职场文书
cf收人广告词大全
2014/03/14 职场文书
2014年党的群众路线教育实践活动总结
2014/04/25 职场文书
农村党员一句话承诺
2014/05/30 职场文书
2014年计划生育协会工作总结
2014/11/14 职场文书
配置Kubernetes外网访问集群
2022/03/31 Servers