pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 函数传参之传值还是传引用的分析
Sep 07 Python
Python cookbook(数据结构与算法)通过公共键对字典列表排序算法示例
Mar 15 Python
Python+OpenCV实现车牌字符分割和识别
Mar 31 Python
python numpy 部分排序 寻找最大的前几个数的方法
Jun 27 Python
Python自定义装饰器原理与用法实例分析
Jul 16 Python
python批量赋值操作实例
Oct 22 Python
Python的垃圾回收机制详解
Aug 28 Python
python实现ftp文件传输功能
Mar 20 Python
Python xpath表达式如何实现数据处理
Jun 13 Python
python实现感知机模型的示例
Sep 30 Python
Python图像识别+KNN求解数独的实现
Nov 13 Python
Python使用sql语句对mysql数据库多条件模糊查询的思路详解
Apr 12 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
php学习笔记 面向对象的构造与析构方法
2011/06/13 PHP
谈谈从phpinfo中能获取哪些值得注意的信息
2017/03/28 PHP
php 输出缓冲 Output Control用法实例详解
2020/03/03 PHP
用JQuery 实现的自定义对话框
2007/03/24 Javascript
jquery写个checkbox——类似邮箱全选功能
2013/03/19 Javascript
javascript闭包的高级使用方法实例
2013/07/04 Javascript
js常用自定义公共函数汇总
2014/01/15 Javascript
JavaScript监听文本框回车事件并过滤文本框空格的方法
2015/04/16 Javascript
动态创建按钮的JavaScript代码
2016/01/29 Javascript
jQuery添加和删除输入文本框标签代码
2016/05/20 Javascript
归纳下js面向对象的几种常见写法总结
2016/08/24 Javascript
微信小程序教程之本地图片上传(leancloud)实例详解
2016/11/16 Javascript
JS图片轮播与索引变色功能实例详解
2017/07/06 Javascript
js数组实现权重概率分配
2017/09/12 Javascript
微信小程序wx.getImageInfo()如何获取图片信息
2018/01/26 Javascript
vue路由懒加载的实现方法
2018/03/12 Javascript
JS实现DOM删除节点操作示例
2018/04/04 Javascript
手写Vue2.0 数据劫持的示例
2021/03/04 Vue.js
Python中列表和元组的相关语句和方法讲解
2015/08/20 Python
更改Ubuntu默认python版本的两种方法python-> Anaconda
2016/12/18 Python
Python数据结构与算法之列表(链表,linked list)简单实现
2017/10/30 Python
Python3网络爬虫中的requests高级用法详解
2019/06/18 Python
python中的global关键字的使用方法
2019/08/20 Python
使用python-opencv读取视频,计算视频总帧数及FPS的实现
2019/12/10 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
2020/01/21 Python
详解Canvas实用库Fabric.js使用手册
2019/01/07 HTML / CSS
美国羽绒床上用品第一品牌:Pacific Coast
2018/08/25 全球购物
波兰最早的运动鞋精品店之一:Street Supply
2019/08/29 全球购物
实习销售业务员自我鉴定
2013/09/21 职场文书
公司行政经理岗位职责
2013/12/24 职场文书
社区学雷锋活动策划方案
2014/01/30 职场文书
会计电算化应届生自荐信
2014/02/25 职场文书
学习经验交流会主持词
2014/04/01 职场文书
2015年教育实习工作总结
2015/04/24 职场文书
浅谈JavaScript浅拷贝和深拷贝
2021/11/07 Javascript
Java获取字符串编码格式实现思路
2022/09/23 Java/Android