pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python下如何让web元素的生成更简单的分析
Jul 17 Python
Python编程生成随机用户名及密码的方法示例
May 05 Python
Python格式化输出%s和%d
May 07 Python
python和shell监控linux服务器的详细代码
Jun 22 Python
python爬虫简单的添加代理进行访问的实现代码
Apr 04 Python
python 字典操作提取key,value的方法
Jun 26 Python
python爬虫添加请求头代码实例
Dec 28 Python
彻底搞懂 python 中文乱码问题(深入分析)
Feb 28 Python
Python3爬虫关于识别点触点选验证码的实例讲解
Jul 30 Python
基于Python实现全自动下载抖音视频
Nov 06 Python
python3判断IP地址的方法
Mar 04 Python
Python编程源码报错解决方法总结经验分享
Oct 05 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
强烈推荐:php.ini中文版(1)
2006/10/09 PHP
PHP中判断变量为空的几种方法分享
2013/08/26 PHP
解析JavaScript中instanceof对于不同的构造器或许都返回true
2013/12/03 Javascript
javascript中使用new与不使用实例化对象的区别
2015/06/22 Javascript
jquery判断当前浏览器的实现代码
2015/11/07 Javascript
教你如何终止JQUERY的$.AJAX请求
2016/02/23 Javascript
AngularJS入门教程之XHR和依赖注入详解
2016/08/18 Javascript
AngularJS折叠菜单实现方法示例
2017/05/18 Javascript
vue2.0之多页面的开发的示例
2018/01/30 Javascript
原生JS实现瀑布流插件
2018/02/06 Javascript
在vue2.0中引用element-ui组件库的方法
2018/06/21 Javascript
RequireJS用法简单示例
2018/08/20 Javascript
详解vue2.6插槽更新v-slot用法总结
2019/03/09 Javascript
原生JS 实现的input输入时表格过滤操作示例
2019/08/03 Javascript
基于Express框架使用POST传递Form数据
2019/08/10 Javascript
vue+axios实现post文件下载
2019/09/25 Javascript
微信小程序以ssm做后台开发的实现示例
2020/04/08 Javascript
vue Treeselect下拉树只能选择第N级元素实现代码
2020/08/31 Javascript
[37:45]2014 DOTA2国际邀请赛中国区预选赛5.21 DT VS Orenda
2014/05/22 DOTA
python 系统调用的实例详解
2017/07/11 Python
python+matplotlib实现鼠标移动三角形高亮及索引显示
2018/01/15 Python
Python爬取商家联系电话以及各种数据的方法
2018/11/10 Python
python实现翻译word表格小程序
2020/02/27 Python
分享unittest单元测试框架中几种常用的用例加载方法
2020/12/02 Python
html5手机端页面可以向右滑动导致样式受影响的问题
2018/06/20 HTML / CSS
中东地区为妈妈们提供一切的头号购物目的地:Sprii
2018/05/06 全球购物
护士试用期自我鉴定
2014/02/08 职场文书
《维生素c的故事》教学反思
2014/02/18 职场文书
廉洁家庭事迹材料
2014/05/15 职场文书
个人简历自荐信
2014/06/26 职场文书
2014年新农村建设工作总结
2014/12/01 职场文书
英文版辞职信
2015/02/28 职场文书
2015年信息中心工作总结
2015/05/25 职场文书
祝福语集锦:朋友新店开业祝福语
2019/12/10 职场文书
python实现简单反弹球游戏
2021/04/12 Python
Python爬虫网络请求之代理服务器和动态Cookies
2022/04/12 Python