pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python找出9个连续的空闲端口
Feb 01 Python
Python3中类、模块、错误与异常、文件的简易教程
Nov 20 Python
Python实现七彩蟒蛇绘制实例代码
Jan 16 Python
python3.5+tesseract+adb实现西瓜视频或头脑王者辅助答题
Jan 17 Python
运行django项目指定IP和端口的方法
May 14 Python
对pandas中to_dict的用法详解
Jun 05 Python
Python中字符串与编码示例代码
May 20 Python
详解python深浅拷贝区别
Jun 24 Python
python中PS 图像调整算法原理之亮度调整
Jun 28 Python
python 列表推导式使用详解
Aug 29 Python
浅谈对python中if、elif、else的误解
Aug 20 Python
Python基本数据类型之字符串str
Jul 21 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
PHP的变量总结 新手推荐
2011/04/18 PHP
ThinkPHP中I(),U(),$this->post()等函数用法
2014/11/22 PHP
php求一个网段开始与结束IP地址的方法
2015/07/09 PHP
PHP反射API示例分享
2016/10/08 PHP
Js 回车换行处理的办法及replace方法应用
2013/01/24 Javascript
javascrip关于继承的小例子
2013/05/10 Javascript
javascript 得到文件后缀名的思路及实现
2020/05/09 Javascript
限制只能输入数字的实现代码
2016/05/16 Javascript
使用jQuery的load方法设计动态加载及解决被加载页面js失效问题
2017/03/01 Javascript
详解Vue方法与事件
2017/03/09 Javascript
基于vue的fullpage.js单页滚动插件
2017/03/20 Javascript
详解Node.js中exports和module.exports的区别
2017/04/19 Javascript
angular中实现控制器之间传递参数的方式
2017/04/24 Javascript
angular select 默认值设置方法
2017/06/23 Javascript
Vue.js项目部署到服务器的详细步骤
2017/07/17 Javascript
详解Webstorm 新建.vue文件支持高亮vue语法和es6语法
2017/10/26 Javascript
详解给Vue2路由导航钩子和axios拦截器做个封装
2018/04/10 Javascript
vue弹窗组件使用方法
2018/04/28 Javascript
Vue中的混入的使用(vue mixins)
2018/06/01 Javascript
js实现图片上传并预览功能
2018/08/06 Javascript
Vue中消息横向滚动时setInterval清不掉的问题及解决方法
2019/08/23 Javascript
jQuery实现的解析本地 XML 文档操作示例
2020/04/30 jQuery
python的urllib模块显示下载进度示例
2014/01/17 Python
使用Python读写及压缩和解压缩文件的示例
2016/07/08 Python
python的dataframe转换为多维矩阵的方法
2018/04/11 Python
python实现在图片上画特定大小角度矩形框
2018/10/24 Python
windows下numpy下载与安装图文教程
2019/04/02 Python
Python图像处理PIL各模块详细介绍(推荐)
2019/07/17 Python
python 爬虫如何实现百度翻译
2020/11/16 Python
BeautifulSoup中find和find_all的使用详解
2020/12/07 Python
python 如何在测试中使用 Mock
2021/03/01 Python
css3一个简易的 LED 数字时钟实现方法
2020/01/15 HTML / CSS
HTML5新特性之type=file文件上传功能
2018/02/02 HTML / CSS
输入N,打印N*N矩阵
2012/02/20 面试题
广播稿:校园广播稿范文
2019/04/17 职场文书
使用Canvas绘制一个游戏人物属性图
2022/03/25 Javascript