pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解


Posted in Python onJanuary 02, 2020

公式

首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。

测试代码(一维)

import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = 0
for i in range(1):
  first = -output[i][label[i]]
second = 0
for i in range(1):
  for j in range(5):
    second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

测试代码(多维)

import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)

print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)

first = [0, 0, 0]
for i in range(3):
  first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
  for j in range(5):
    second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
  res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

nn.CrossEntropyLoss()中的计算方法

注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。

在pytorch中,CrossEntropyLoss计算公式为:

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

CrossEntropyLoss带权重的计算公式为(默认weight=None):

pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解

以上这篇pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
TensorFlow实现卷积神经网络CNN
Mar 09 Python
python使用socket创建tcp服务器和客户端
Apr 12 Python
python使用Matplotlib绘制分段函数
Sep 25 Python
PyQt5 QListWidget选择多项并返回的实例
Jun 17 Python
pandas将多个dataframe以多个sheet的形式保存到一个excel文件中
Oct 10 Python
Python中的list与tuple集合区别解析
Oct 12 Python
基于python判断目录或者文件代码实例
Nov 29 Python
JAVA SWT事件四种写法实例解析
Jun 05 Python
让你相见恨晚的十个Python骚操作
Nov 18 Python
python 写一个水果忍者游戏
Jan 13 Python
字典算法实现及操作 --python(实用)
Mar 31 Python
Python中生成随机数据安全性、多功能性、用途和速度方面进行比较
Apr 14 Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
pytorch标签转onehot形式实例
Jan 02 #Python
You might like
PHP 反射机制实现动态代理的代码
2008/10/22 PHP
PHP中防止SQL注入攻击和XSS攻击的两个简单方法
2010/04/15 PHP
PHP实现的连贯操作、链式操作实例
2014/07/08 PHP
PHP高级编程实例:编写守护进程
2014/09/02 PHP
php遍历删除整个目录及文件的方法
2015/03/13 PHP
php遍历解析xml字符串的方法
2016/05/05 PHP
php实现微信支付之退款功能
2018/05/30 PHP
PHP基于openssl实现非对称加密代码实例
2020/06/19 PHP
jquery在IE、FF浏览器的差别详细探讨
2013/04/28 Javascript
用js将内容复制到剪贴板兼容浏览器
2014/03/18 Javascript
教你用jquery实现iframe自适应高度
2014/06/11 Javascript
js+jquery实现图片裁剪功能
2015/01/02 Javascript
JS动态给对象添加事件的简单方法
2016/07/19 Javascript
AngularJS中$http使用的简单介绍
2017/03/17 Javascript
AngularJS实现的JSONP跨域访问数据传输功能详解
2017/07/20 Javascript
微信小程序实现的涂鸦功能示例【附源码下载】
2018/01/12 Javascript
vue-star评星组件开发实例
2018/03/01 Javascript
webgl实现物体描边效果的方法介绍
2019/11/27 Javascript
Node.js API详解之 module模块用法实例分析
2020/05/13 Javascript
简单谈谈Python中函数的可变参数
2016/09/02 Python
python Crypto模块的安装与使用方法
2017/12/21 Python
基于pip install django失败时的解决方法
2018/06/12 Python
Python数据可视化实现正态分布(高斯分布)
2019/08/21 Python
pytorch 实现在预训练模型的 input上增减通道
2020/01/06 Python
Keras使用tensorboard显示训练过程的实例
2020/02/15 Python
Pytorch高阶OP操作where,gather原理
2020/04/30 Python
Python 连接 MySQL 的几种方法
2020/09/09 Python
html5指南-1.html5全局属性(html5 global attributes)深入理解
2013/01/07 HTML / CSS
优秀的导游求职信范文
2014/04/06 职场文书
中学生评语大全
2014/04/18 职场文书
门面房租房协议书
2014/08/20 职场文书
生日答谢词
2015/01/05 职场文书
迟到检讨书
2015/01/26 职场文书
关于颐和园的导游词
2015/01/30 职场文书
幼儿园小班教学反思
2016/03/03 职场文书
检讨书格式
2019/04/25 职场文书