Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的__SLOTS__属性使用示例
Feb 18 Python
Python实现合并字典的方法
Jul 07 Python
Python+matplotlib绘制不同大小和颜色散点图实例
Jan 19 Python
Numpy array数据的增、删、改、查实例
Jun 04 Python
pandas 条件搜索返回列表的方法
Oct 30 Python
Djang的model创建的字段和参数详解
Jul 27 Python
django+echart数据动态显示的例子
Aug 12 Python
python如何从文件读取数据及解析
Sep 19 Python
Python自动采集微信联系人的实现示例
Feb 28 Python
keras 多gpu并行运行案例
Jun 10 Python
python连接mysql有哪些方法
Jun 24 Python
python中count函数知识点浅析
Dec 17 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
上海牌131型七灯四波段四喇叭一级收音机
2021/03/02 无线电
php UTF-8、Unicode和BOM问题
2010/05/18 PHP
php遍历数组的方法分享
2012/03/22 PHP
php 计划任务 检测用户连接状态
2012/03/29 PHP
PHP中的Streams详细介绍
2014/11/12 PHP
php实现模拟post请求用法实例
2015/07/11 PHP
PHP制作登录异常ip检测功能的实例代码
2016/11/16 PHP
laravel按天、按小时,查询数据的实例
2019/10/09 PHP
FireFox与IE 下js兼容触发click事件的代码
2008/11/20 Javascript
jQuery+css3实现Ajax点击后动态删除功能的方法
2015/08/10 Javascript
Bootstrap 粘页脚效果
2016/03/28 Javascript
第一次接触Bootstrap框架
2016/10/24 Javascript
AngularJS之页面跳转Route实例代码
2017/03/10 Javascript
基于Node.js模板引擎教程-jade速学与实战1
2017/09/17 Javascript
node中间层实现文件上传功能
2018/06/11 Javascript
Vue CLI 3.x 自动部署项目至服务器的方法
2019/04/02 Javascript
vue elementUI 表单校验功能之数组多层嵌套
2019/06/04 Javascript
layer关闭当前窗口页面以及确认取消按钮的方法
2019/09/09 Javascript
[53:36]Liquid vs VP Supermajor决赛 BO 第三场 6.10
2018/07/05 DOTA
基于Python的接口测试框架实例
2016/11/04 Python
python写日志文件操作类与应用示例
2019/07/01 Python
用python介绍4种常用的单链表翻转的方法小结
2020/02/24 Python
python openssl模块安装及用法
2020/12/06 Python
python+selenium爬取微博热搜存入Mysql的实现方法
2021/01/27 Python
详解Html5页面实现下载文件(apk、txt等)的三种方式
2018/10/22 HTML / CSS
Perricone MD裴礼康美国官网:抗衰老护肤品
2016/09/26 全球购物
节省高达65%的城市景点费用:Go City
2019/07/06 全球购物
澳大利亚100%丝绸多彩度假装商店:TheSwankStore
2019/09/04 全球购物
北京银河万佳Java面试题
2012/03/21 面试题
好的演讲稿开场白
2013/12/30 职场文书
劳资协议书范本
2014/04/23 职场文书
党政领导班子四风问题对照检查材料思想汇报
2014/10/02 职场文书
党员干部对十八届四中全会的期盼
2014/10/17 职场文书
导游词之吉林吉塔
2019/11/11 职场文书
浅谈CSS不规则边框的生成方案
2021/05/25 HTML / CSS
Navicat Premium自定义 sql 标签的创建方式
2022/09/23 数据库