Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中文编码问题小结
Sep 28 Python
详解Python中find()方法的使用
May 18 Python
使用rst2pdf实现将sphinx生成PDF
Jun 07 Python
python3+PyQt5+Qt Designer实现堆叠窗口部件
Apr 20 Python
Python3多线程操作简单示例
May 22 Python
Python读取Excel表格,并同时画折线图和柱状图的方法
Oct 14 Python
简单了解Python3里的一些新特性
Jul 13 Python
Django中间件基础用法详解
Jul 18 Python
Python 多线程共享变量的实现示例
Apr 17 Python
Python 判断时间是否在时间区间内的实例
May 16 Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 Python
在pycharm创建scrapy项目的实现步骤
Dec 01 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
dede全站URL静态化改造[070414更正]
2007/04/17 PHP
PHP编实现程动态图像的创建代码
2008/09/28 PHP
PHP iconv 解决utf-8和gb2312编码转换问题
2010/04/12 PHP
PHP $_FILES函数详解
2011/03/09 PHP
PHP中使用Imagick读取pdf并生成png缩略图实例
2015/01/21 PHP
PHP实现负载均衡下的session共用功能
2018/04/17 PHP
JavaScript使用cookie
2007/02/02 Javascript
Mootools 1.2教程 类(一)
2009/09/15 Javascript
javascript event 事件解析
2011/01/31 Javascript
获取div编辑框,textarea,input text的光标位置 兼容IE,FF和Chrome的方法介绍
2012/11/08 Javascript
javascript计算用户打开网页的停留时间
2014/01/09 Javascript
javascript中match函数的用法小结
2014/02/08 Javascript
基于javascript实现判断移动终端浏览器版本信息
2014/12/09 Javascript
JavaScript实现自动弹出窗口并自动关闭窗口的方法
2015/08/06 Javascript
整理Javascript数组学习笔记
2015/11/29 Javascript
node.js使用cluster实现多进程
2016/03/17 Javascript
json实现添加、遍历与删除属性的方法
2016/06/17 Javascript
微信小程序 两种为对象属性赋值的方式详解
2017/02/23 Javascript
JavaScript编写棋盘覆盖代码详解
2017/08/28 Javascript
json2.js 入门教程之使用方法与实例分析
2017/09/14 Javascript
深入浅析Vue中的slots/scoped slots
2018/04/03 Javascript
使用webpack搭建vue项目实现脚手架功能
2019/03/15 Javascript
[02:40]DOTA2英雄基础教程 巨牙海民
2013/12/23 DOTA
[57:47]Fnatic vs Winstrike 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
利用Python+Java调用Shell脚本时的死锁陷阱详解
2018/01/24 Python
Django shell调试models输出的SQL语句方法
2019/08/29 Python
Python中if有多个条件处理方法
2020/02/26 Python
Python中三维坐标空间绘制的实现
2020/09/22 Python
女性时尚网购:Chic Me
2019/07/30 全球购物
优秀员工年终发言演讲稿
2014/01/01 职场文书
中专生自我鉴定范文
2014/02/02 职场文书
公益活动邀请函
2014/02/05 职场文书
揭牌仪式主持词
2014/03/19 职场文书
开展党的群众路线教育实践活动个人对照检查材料
2014/11/05 职场文书
《伯牙绝弦》教学反思
2016/02/16 职场文书
Python-OpenCV实现图像缺陷检测的实例
2021/06/11 Python