pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Cython 三分钟入门教程
Sep 17 Python
简述Python中的面向对象编程的概念
Apr 27 Python
Mac中Python 3环境下安装scrapy的方法教程
Oct 26 Python
Python实现12306火车票抢票系统
Jul 04 Python
Python使用mongodb保存爬取豆瓣电影的数据过程解析
Aug 14 Python
树莓派安装OpenCV3完整过程的实现
Oct 10 Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 Python
python 常见的反爬虫策略
Sep 27 Python
互斥锁解决 Python 中多线程共享全局变量的问题(推荐)
Sep 28 Python
python实现代码审查自动回复消息
Feb 01 Python
利用Python实时获取steam特惠游戏数据
Jun 25 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
PHP中根据IP地址判断城市实现城市切换或跳转代码
2012/09/04 PHP
php+mysql数据库实现无限分类的方法
2014/12/12 PHP
php获取错误信息的方法
2015/07/17 PHP
Laravel中Facade的加载过程与原理详解
2017/09/22 PHP
php实现微信发红包功能
2018/07/13 PHP
JS URL传中文参数引发的乱码问题
2009/09/02 Javascript
两个比较有用的Javascript工具函数代码
2010/02/17 Javascript
设置jQueryUI DatePicker默认语言为中文
2016/06/04 Javascript
vue 2.0封装model组件的方法
2017/08/03 Javascript
vue首次赋值不触发watch的解决方法
2018/09/11 Javascript
详解Node.js中path模块的resolve()和join()方法的区别
2018/10/29 Javascript
浅谈vue3中effect与computed的亲密关系
2019/10/10 Javascript
js实现文字头像的生成代码
2020/03/07 Javascript
Python中List.count()方法的使用教程
2015/05/20 Python
深入解析Python编程中JSON模块的使用
2015/10/15 Python
windows下ipython的安装与使用详解
2016/10/20 Python
pandas 对每一列数据进行标准化的方法
2018/06/09 Python
Numpy之random函数使用学习
2019/01/29 Python
Python中format()格式输出全解
2019/04/12 Python
python logging模块的使用总结
2019/07/09 Python
python爬虫开发之Beautiful Soup模块从安装到详细使用方法与实例
2020/03/09 Python
Python面向对象程序设计之继承、多态原理与用法详解
2020/03/23 Python
python正则表达式 匹配反斜杠的操作方法
2020/08/07 Python
html5各种页面切换效果和模态对话框用法总结
2014/12/15 HTML / CSS
德国奢侈品网上商城:Mytheresa
2016/08/24 全球购物
中国最大的名表商城:万表网
2016/08/29 全球购物
在C#中如何实现多态
2014/07/02 面试题
企业统计员岗位职责
2013/12/13 职场文书
高中生家长寄语大全
2014/04/03 职场文书
购房协议书
2014/04/11 职场文书
技术比武方案
2014/05/19 职场文书
小学竞选班长演讲稿
2014/09/09 职场文书
个人授权委托书范本格式
2014/10/12 职场文书
2014年班组工作总结
2014/11/20 职场文书
2016年第32个教师节红领巾广播稿
2015/12/18 职场文书
2019年二手房买卖合同范本
2019/10/14 职场文书