Pytorch损失函数nn.NLLLoss2d()用法说明


Posted in Python onJuly 07, 2020

最近做显著星检测用到了NLL损失函数

对于NLL函数,需要自己计算log和softmax的概率值,然后从才能作为输入

输入 [batch_size, channel , h, w]

Pytorch损失函数nn.NLLLoss2d()用法说明

目标 [batch_size, h, w]

输入的目标矩阵,每个像素必须是类型.举个例子。第一个像素是0,代表着类别属于输入的第1个通道;第二个像素是0,代表着类别属于输入的第0个通道,以此类推。

x = Variable(torch.Tensor([[[1, 2, 1],
       [2, 2, 1],
       [0, 1, 1]],
       [[0, 1, 3],
       [2, 3, 1],
       [0, 0, 1]]]))

x = x.view([1, 2, 3, 3])
print("x输入", x)

这里输入x,并改成[batch_size, channel , h, w]的格式。

soft = nn.Softmax(dim=1)

log_soft = nn.LogSoftmax(dim=1)

然后使用softmax函数计算每个类别的概率,这里dim=1表示从在1维度

上计算,也就是channel维度。logsoftmax是计算完softmax后在计算log值

Pytorch损失函数nn.NLLLoss2d()用法说明

手动计算举个栗子:第一个元素

Pytorch损失函数nn.NLLLoss2d()用法说明

y = Variable(torch.LongTensor([[1, 0, 1],
       [0, 0, 1],
       [1, 1, 1]]))

y = y.view([1, 3, 3])

输入label y,改变成[batch_size, h, w]格式

loss = nn.NLLLoss2d()
out = loss(x, y)
print(out)

输入函数,得到loss=0.7947

来手动计算

第一个label=1,则 loss=-1.3133

第二个label=0, 则loss=-0.3133

.
…
…
loss= -(-1.3133-0.3133-0.1269-0.6931-1.3133-0.6931-0.6931-1.3133-0.6931)/9 =0.7947222222222223

是一致的

注意:这个函数会对每个像素做平均,每个batch也会做平均,这里有9个像素,1个batch_size。

补充知识:PyTorch:NLLLoss2d

我就废话不多说了,大家还是直接看代码吧~

import torch
import torch.nn as nn
from torch import autograd
import torch.nn.functional as F
 
inputs_tensor = torch.FloatTensor([
[[2, 4],
 [1, 2]],
[[5, 3],
 [3, 0]],
[[5, 3],
 [5, 2]],
[[4, 2],
 [3, 2]],
 ])
inputs_tensor = torch.unsqueeze(inputs_tensor,0)
# inputs_tensor = torch.unsqueeze(inputs_tensor,1)
print '--input size(nBatch x nClasses x height x width): ', inputs_tensor.shape
 
targets_tensor = torch.LongTensor([
 [0, 2],
 [2, 3]
])
 
targets_tensor = torch.unsqueeze(targets_tensor,0)
print '--target size(nBatch x height x width): ', targets_tensor.shape
 
inputs_variable = autograd.Variable(inputs_tensor, requires_grad=True)
inputs_variable = F.log_softmax(inputs_variable)
targets_variable = autograd.Variable(targets_tensor)
 
loss = nn.NLLLoss2d()
output = loss(inputs_variable, targets_variable)
print '--NLLLoss2d: {}'.format(output)

以上这篇Pytorch损失函数nn.NLLLoss2d()用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中精确输出JSON浮点数的方法
Apr 18 Python
python删除特定文件的方法
Jul 30 Python
numpy中的高维数组转置实例
Apr 17 Python
Python常见MongoDB数据库操作实例总结
Jul 24 Python
python得到qq句柄,并显示在前台的方法
Oct 14 Python
python版本五子棋的实现代码
Dec 11 Python
python实现文件助手中查看微信撤回消息
Apr 29 Python
浅谈pycharm使用及设置方法
Sep 09 Python
TensorFLow 不同大小图片的TFrecords存取实例
Jan 20 Python
python基本算法之实现归并排序(Merge sort)
Sep 01 Python
pycharm进入时每次都是insert模式的解决方式
Feb 05 Python
Python IO文件管理的具体使用
Mar 20 Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 #Python
pytorch加载自己的图像数据集实例
Jul 07 #Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
You might like
Yii2.0高级框架数据库增删改查的一些操作
2015/11/16 PHP
PHP Static延迟静态绑定用法分析
2016/03/16 PHP
基于PHP实现通过照片获取ip地址
2016/04/26 PHP
php 广告点击统计代码(php+mysql)
2018/02/21 PHP
对联广告js flash激活
2006/10/19 Javascript
JavaScript 学习笔记(十六) js事件
2010/02/01 Javascript
JavaScript动态修改背景颜色的方法
2015/04/16 Javascript
IE下JS保存图片的简单实例
2016/07/15 Javascript
微信小程序 form组件详解
2016/10/25 Javascript
jsTree使用记录实例
2016/12/01 Javascript
使用bootstrap-paginator.js 分页来进行ajax 异步分页请求示例
2017/03/09 Javascript
老生常谈jacascript DOM节点获取
2017/04/17 Javascript
JavaScrpt的面向对象全面解析
2017/05/09 Javascript
VUE2.0 ElementUI2.0表格el-table自适应高度的实现方法
2018/11/28 Javascript
Laravel admin实现消息提醒、播放音频功能
2019/07/10 Javascript
koa中间件核心(koa-compose)源码解读分析
2020/06/15 Javascript
vue引入静态js文件的方法
2020/06/20 Javascript
解决qrcode.js生成二维码时必须定义一个空div的问题
2020/07/09 Javascript
解决vue项目运行npm run serve报错的问题
2020/10/26 Javascript
使用PYTHON接收多播数据的代码
2012/03/01 Python
Python 字典dict使用介绍
2014/11/30 Python
Python浅拷贝与深拷贝用法实例
2015/05/09 Python
在Python程序中操作文件之flush()方法的使用教程
2015/05/24 Python
Python的Django框架中的数据过滤功能
2015/07/17 Python
python使用邻接矩阵构造图代码示例
2017/11/10 Python
Python中捕获键盘的方式详解
2019/03/28 Python
Python 剪绳子的多种思路实现(动态规划和贪心)
2020/02/24 Python
python datetime处理时间小结
2020/04/16 Python
Django ORM 查询表中某列字段值的方法
2020/04/30 Python
Square Off美国/加拿大:世界上最聪明的国际象棋棋盘
2018/12/06 全球购物
函授大学生自我鉴定
2014/02/05 职场文书
酒店管理求职信
2014/06/09 职场文书
村干部四风问题整改措施
2014/09/30 职场文书
元旦联欢晚会主持词
2015/07/01 职场文书
Nginx反向代理及负载均衡如何实现(基于linux)
2021/03/31 Servers
利用Python实现翻译HTML中的文本字符串
2022/06/21 Python