pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用中文的方法
Feb 19 Python
基于python select.select模块通信的实例讲解
Sep 21 Python
Python自动化运维之IP地址处理模块详解
Dec 10 Python
python语音识别实践之百度语音API
Aug 30 Python
使用python socket分发大文件的实现方法
Jul 08 Python
解决pandas展示数据输出时列名不能对齐的问题
Nov 18 Python
Python算法中的时间复杂度问题
Nov 19 Python
Python 调用有道翻译接口实现翻译
Mar 02 Python
Python项目跨域问题解决方案
Jun 22 Python
Python上下文管理器Content Manager
Jun 26 Python
python中pd.cut()与pd.qcut()的对比及示例
Jun 16 Python
python语言中pandas字符串分割str.split()函数
Aug 05 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
无线电广播的开始
2002/01/30 无线电
在PHP的图形函数中显示汉字
2006/10/09 PHP
怎样在UNIX系统下安装php3
2006/10/09 PHP
php error_log 函数的使用
2009/04/13 PHP
PHP文件操作实现代码分享
2011/09/01 PHP
PHP关于htmlspecialchars、strip_tags、addslashes的解释
2014/07/04 PHP
PHP+APACHE实现网址伪静态
2015/02/22 PHP
PHP中递归的实现实例详解
2017/11/14 PHP
ie 调试javascript的工具
2009/04/29 Javascript
你需要知道的10个最佳javascript开发实践小结
2012/04/15 Javascript
jQuery拖动图片删除示例
2013/05/10 Javascript
JavaScript中使用Callback控制流程介绍
2015/03/16 Javascript
javascript实现图片跟随鼠标移动效果的方法
2015/05/13 Javascript
基于MVC4+EasyUI的Web开发框架形成之旅之界面控件的使用
2015/12/16 Javascript
基于 Node.js 实现前后端分离
2016/04/23 Javascript
Javascript点击按钮随机改变数字与其颜色
2016/09/01 Javascript
jquery-mobile基础属性与用法详解
2016/11/23 Javascript
JavaScript基于activexobject连接远程数据库SQL Server 2014的方法
2017/07/12 Javascript
AngularJS 仿微信图片手势缩放的实例
2017/09/28 Javascript
webpack-url-loader 解决项目中图片打包路径问题
2019/02/15 Javascript
多个vue子路由文件自动化合并的方法
2019/09/03 Javascript
vue 组件内获取actions的response方式
2019/11/08 Javascript
Bootstrap实现前端登录页面带验证码功能完整示例
2020/03/26 Javascript
vue与iframe之间的信息交互的实现
2020/04/08 Javascript
解决vue 使用setTimeout,离开当前路由setTimeout未销毁的问题
2020/07/21 Javascript
[01:06]DOTA2隆重推出2016冬季勇士令状 内含上海特级锦标赛互动指南
2016/02/17 DOTA
[05:04]DOTA2上海特级锦标赛主赛事第二日TOP10
2016/03/04 DOTA
python中用logging实现日志滚动和过期日志删除功能
2019/08/20 Python
Python数据持久化存储实现方法分析
2019/12/21 Python
Python Spyder 调出缩进对齐线的操作
2021/02/26 Python
6PM官网:折扣鞋、服装及配饰
2018/08/03 全球购物
英国女性时尚品牌:Apricot
2018/12/04 全球购物
意大利辅助药品、药物和补品在线销售:FarmaEurope
2020/04/29 全球购物
大学生村官座谈会发言材料
2014/05/25 职场文书
2016高考冲刺决心书
2015/09/23 职场文书
初中化学教学反思
2016/02/22 职场文书