pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过ftplib登录到ftp服务器的方法
May 08 Python
Django实战之用户认证(用户登录与注销)
Jul 16 Python
python使用Plotly绘图工具绘制柱状图
Apr 01 Python
django认证系统 Authentication使用详解
Jul 22 Python
Python 寻找局部最高点的实现
Dec 05 Python
简单了解Django ORM常用字段类型及参数配置
Jan 07 Python
django3.02模板中的超链接配置实例代码
Feb 04 Python
python+requests接口压力测试500次,查看响应时间的实例
Apr 30 Python
python numpy库np.percentile用法说明
Jun 08 Python
Python的3种运行方式:命令行窗口、Python解释器、IDLE的实现
Oct 10 Python
DRF使用simple JWT身份验证的实现
Jan 14 Python
Pygame Event事件模块的详细示例
Nov 17 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
destoon数据库表说明汇总
2014/07/15 PHP
destoon实现调用图文新闻的方法
2014/08/21 PHP
ThinkPHP内置jsonRPC的缺陷分析
2014/12/18 PHP
php检索或者复制远程文件的方法
2015/03/13 PHP
PHP的mysqli_rollback()函数讲解
2019/01/23 PHP
Javascript !!的作用
2008/12/04 Javascript
tangram框架响应式加载图片方法
2013/11/21 Javascript
ajax提交表单实现网页无刷新注册示例
2014/05/08 Javascript
js 左右悬浮对联广告代码示例
2014/12/12 Javascript
JavaScript字符串常用类使用方法汇总
2015/04/14 Javascript
jQuery实现限制textarea文本框输入字符数量的方法
2015/05/28 Javascript
JavaScript实现Fly Bird小游戏
2016/12/15 Javascript
深入理解JavaScript中的for循环
2017/02/07 Javascript
Angular4 ElementRef的应用
2018/02/26 Javascript
浅析JavaScript异步代码优化
2019/03/18 Javascript
vue2.0自定义指令示例代码详解
2019/04/25 Javascript
vue-cli3 项目优化之通过 node 自动生成组件模板 generate View、Component
2019/04/30 Javascript
浅谈Three.js截图并下载的大坑
2019/11/01 Javascript
基于vue3.0.1beta搭建仿京东的电商H5项目
2020/05/06 Javascript
python将.ppm格式图片转换成.jpg格式文件的方法
2018/10/27 Python
详解python做UI界面的方法
2019/02/27 Python
浅谈tensorflow之内存暴涨问题
2020/02/05 Python
Python使用itcaht库实现微信自动收发消息功能
2020/07/13 Python
详解CSS3的opacity属性设置透明效果的用法
2016/05/09 HTML / CSS
加拿大建筑和装修专家:Reno-Depot
2017/12/21 全球购物
LivingSocial英国:英国本地优惠
2019/02/22 全球购物
GetYourGuide台湾:预订旅游活动、景点和旅游项目
2019/06/10 全球购物
四年大学生活的自我评价范文
2014/02/07 职场文书
酒店营销策划方案
2014/02/07 职场文书
党员公开承诺书
2014/03/25 职场文书
关于感恩的演讲稿200字
2014/08/26 职场文书
总经理助理岗位职责
2015/01/31 职场文书
不同意离婚答辩状
2015/05/22 职场文书
python 爬取豆瓣网页的示例
2021/04/13 Python
Python 数据可视化之Matplotlib详解
2021/11/02 Python
CSS三大特性继承性、层叠性和优先级详解
2022/01/18 HTML / CSS