pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python抓取网页图片示例(python爬虫)
Apr 27 Python
Python中用Spark模块的使用教程
Apr 13 Python
在Python的Django框架中加载模版的方法
Jul 16 Python
python从入门到精通(DAY 2)
Dec 20 Python
Python实现的破解字符串找茬游戏算法示例
Sep 25 Python
python实现一个简单的ping工具方法
Jan 31 Python
Python django框架应用中实现获取访问者ip地址示例
May 17 Python
python自动化之Ansible的安装教程
Jun 13 Python
Python实现TCP通信的示例代码
Sep 09 Python
Python线程条件变量Condition原理解析
Jan 20 Python
Python实现PIL图像处理库绘制国际象棋棋盘
Jul 16 Python
python神经网络 使用Keras构建RNN训练
May 04 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
php 文件状态缓存带来的问题
2008/12/14 PHP
PHPMailer发送HTML内容、带附件的邮件实例
2014/07/01 PHP
php中print(),print_r(),echo()的区别详解
2014/12/01 PHP
Laravel使用RabbitMQ的方法示例
2019/06/18 PHP
浅析JQuery获取和设置Select选项的常用方法总结
2013/07/04 Javascript
JS(JQuery)操作Array的相关方法介绍
2014/02/11 Javascript
node.js中的fs.futimesSync方法使用说明
2014/12/17 Javascript
jQuery及JS实现循环中暂停的方法
2015/02/02 Javascript
javascript 闭包详解
2015/07/02 Javascript
javascript从定义到执行 你不知道的那些事
2016/01/04 Javascript
js实现页面a向页面b传参的方法
2016/05/29 Javascript
HTML页面定时跳转方法解析(2种任选)
2016/12/22 Javascript
详解JS对象封装的常用方式
2016/12/30 Javascript
jQuery插件autocomplete使用详解
2017/02/04 Javascript
jquery ui sortable拖拽后保存位置
2017/04/27 jQuery
Nodejs搭建wss服务器教程
2017/05/24 NodeJs
jquery实现放大镜简洁代码(推荐)
2017/06/08 jQuery
AngularJS实现的自定义过滤器简单示例
2019/02/02 Javascript
Vue如何基于es6导入外部js文件
2020/05/15 Javascript
python 迭代器和iter()函数详解及实例
2017/03/21 Python
python3操作微信itchat实现发送图片
2018/02/24 Python
java中两个byte数组实现合并的示例
2018/05/09 Python
完美解决Pycharm无法导入包的问题 Unresolved reference
2018/05/18 Python
详解Python3的TFTP文件传输
2018/06/26 Python
浅谈Python中函数的定义及其调用方法
2019/07/19 Python
Python Pandas数据中对时间的操作
2019/07/30 Python
python 实现二维字典的键值合并等函数
2019/12/06 Python
使用jupyter notebook将文件保存为Markdown,HTML等文件格式
2020/04/14 Python
《小池塘》教学反思
2014/02/28 职场文书
理想演讲稿范文
2014/05/21 职场文书
工地宣传标语
2014/06/18 职场文书
文明班级申报材料
2014/12/24 职场文书
南湾猴岛导游词
2015/02/09 职场文书
单位收入证明范本
2015/06/18 职场文书
八年级数学教学反思
2016/02/17 职场文书
《分一些蚊子进来》读后感3篇
2020/01/09 职场文书