pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python调用短信猫控件实现发短信功能实例
Jul 04 Python
python对数组进行反转的方法
May 20 Python
python logging 日志轮转文件不删除问题的解决方法
Aug 02 Python
python利用matplotlib库绘制饼图的方法示例
Dec 18 Python
pandas数据分组和聚合操作方法
Apr 11 Python
朴素贝叶斯Python实例及解析
Nov 19 Python
Python 加密与解密小结
Dec 06 Python
Python面向对象程序设计之私有属性及私有方法示例
Apr 08 Python
Python产生一个数值范围内的不重复的随机数的实现方法
Aug 21 Python
对python中list的五种查找方法说明
Jul 13 Python
最简单的matplotlib安装教程(小白)
Jul 28 Python
Python使用paramiko连接远程服务器执行Shell命令的实现
Mar 04 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
深入apache配置文件httpd.conf的部分参数说明
2013/06/28 PHP
PHP SPL标准库中的常用函数介绍
2015/05/11 PHP
PHP关键特性之命名空间实例详解
2017/05/06 PHP
TP5框架实现上传多张图片的方法分析
2020/03/29 PHP
PHP dirname(__FILE__)原理及用法解析
2020/10/28 PHP
JS 页面内容搜索,类似于 Ctrl+F功能的实现代码
2007/08/13 Javascript
使Ext的Template可以解析二层的json数据的方法
2007/12/22 Javascript
基于JQuery.timer插件实现一个计时器
2010/04/25 Javascript
用js获取电脑信息(是使用与IE浏览器)
2013/01/15 Javascript
怎么清空javascript数组
2013/05/11 Javascript
jquery 文本上下无缝滚动,鼠标放上去就停止 小例子
2013/06/05 Javascript
jQuery中.live()方法的用法深入解析
2013/12/30 Javascript
jquery对象与DOM对象转化
2017/02/08 Javascript
Node.js如何实现注册邮箱激活功能 (常见)
2017/07/23 Javascript
jQuery表单校验插件validator使用方法详解
2020/02/18 jQuery
javascript设计模式 ? 工厂模式原理与应用实例分析
2020/04/09 Javascript
Python入门篇之字符串
2014/10/17 Python
python Matplotlib画图之调整字体大小的示例
2017/11/20 Python
Django 生成登陆验证码代码分享
2017/12/12 Python
python实现决策树、随机森林的简单原理
2018/03/26 Python
python创造虚拟环境方法总结
2019/03/04 Python
Python检测数据类型的方法总结
2019/05/20 Python
python 实现返回一个列表中出现次数最多的元素方法
2019/06/11 Python
python用类实现文章敏感词的过滤方法示例
2019/10/27 Python
python能做哪些生活有趣的事情
2020/09/09 Python
CSS3 实现时间轴动画
2020/11/25 HTML / CSS
Brasty波兰:香水、化妆品、手表网上商店
2019/04/15 全球购物
广州某公司软件工程师面试题
2014/12/22 面试题
学生自我评价范文
2014/02/02 职场文书
助理政工师申报材料
2014/06/03 职场文书
大学毕业典礼演讲稿
2014/09/09 职场文书
党风廉正建设责任书
2015/01/29 职场文书
交流会主持词
2015/07/02 职场文书
2016年小学六一儿童节活动总结
2016/04/06 职场文书
用Python可视化新冠疫情数据
2022/01/18 Python
python数据处理之Pandas类型转换
2022/04/28 Python