Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Flask框架信号用法实例分析
Jul 24 Python
Tesserocr库的正确安装方式
Oct 19 Python
python批量下载网站马拉松照片的完整步骤
Dec 05 Python
对python以16进制打印字节数组的方法详解
Jan 24 Python
Python 脚本实现淘宝准点秒杀功能
Nov 13 Python
Python3 shutil(高级文件操作模块)实例用法总结
Feb 19 Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 Python
python 实现分组求和与分组累加求和代码
May 18 Python
Python偏函数Partial function使用方法实例详解
Jun 17 Python
python右对齐的实例方法
Jul 05 Python
python 制作一个gui界面的翻译工具
May 14 Python
Python数据可视化之基于pyecharts实现的地理图表的绘制
Jun 10 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
聊天室php&amp;mysql(三)
2006/10/09 PHP
php上传文件中文文件名乱码的解决方法
2013/11/01 PHP
php判断/计算闰年的方法小结【三种方法】
2019/07/06 PHP
redis+php实现微博(三)微博列表功能详解
2019/09/23 PHP
PHP数组基本用法与知识点总结
2020/06/02 PHP
Jquery 弹出层插件实现代码
2009/10/24 Javascript
Jvascript学习实践案例(开发常用)
2012/06/25 Javascript
javascript动画浅析
2012/08/30 Javascript
用js来获取上传的文件名纯粹是为了美化而用
2013/10/23 Javascript
获取3个数组不重复的值的具体实现
2013/12/30 Javascript
Jquery easyUI 更新行示例
2014/03/06 Javascript
jquery实现自定义图片裁剪功能【推荐】
2017/03/08 Javascript
JS如何设置元素样式的方法示例
2017/08/28 Javascript
高性能的javascript之加载顺序与执行原理篇
2018/01/14 Javascript
Node.js的Koa实现JWT用户认证方法
2018/05/05 Javascript
使用elementUI实现将图片上传到本地的示例
2018/09/04 Javascript
JS 实现微信扫一扫功能
2018/09/14 Javascript
Vue 中可以定义组件模版的几种方式
2019/08/06 Javascript
[52:29]DOTA2上海特级锦标赛主赛事日 - 2 胜者组第一轮#3Secret VS OG第三局
2016/03/03 DOTA
python消费kafka数据批量插入到es的方法
2018/12/27 Python
Python实现的大数据分析操作系统日志功能示例
2019/02/11 Python
Python3.5面向对象与继承图文实例详解
2019/04/24 Python
django框架事务处理小结【ORM 事务及raw sql,customize sql 事务处理】
2019/06/27 Python
在脚本中单独使用django的ORM模型详解
2020/04/01 Python
简单了解Python字典copy与赋值的区别
2020/09/16 Python
你应该知道的30个css选择器
2014/03/19 HTML / CSS
印尼极简主义和实惠的在线家具店:Fabelio
2019/03/27 全球购物
乌克兰电子产品和家用电器购物网站:TOUCH
2019/08/09 全球购物
工业自动化专业毕业生推荐信
2013/11/18 职场文书
麦当劳辞职信范文
2014/01/18 职场文书
优秀教师事迹材料
2014/12/15 职场文书
2014年语文教师工作总结
2014/12/18 职场文书
经理聘任证明
2015/03/02 职场文书
体检通知范文
2015/04/21 职场文书
小学班长竞选稿
2015/11/20 职场文书
Anaconda配置各版本Pytorch的实现
2021/08/07 Python