Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python监控网站运行异常并发送邮件的方法
Mar 13 Python
粗略分析Python中的内存泄漏
Apr 23 Python
Python守护进程用法实例分析
Jun 04 Python
Python的SQLalchemy模块连接与操作MySQL的基础示例
Jul 11 Python
Python matplotlib以日期为x轴作图代码实例
Nov 22 Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 Python
适合Python初学者的一些编程技巧
Feb 12 Python
Python threading.local代码实例及原理解析
Mar 16 Python
Python中操作各种多媒体,视频、音频到图片的代码详解
Jun 04 Python
Python爬虫与反爬虫大战
Jul 30 Python
pycharm 实现光标快速移动到括号外或行尾的操作
Feb 05 Python
Python turtle实现贪吃蛇游戏
Jun 18 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
php版微信返回用户text输入的方法
2016/11/14 PHP
用于判断用户注册时,密码强度的JS代码
2009/01/01 Javascript
超越Jquery_01_isPlainObject分析与重构
2010/10/20 Javascript
JS获取月的最后一天与JS得到一个月份最大天数的实例代码
2013/12/16 Javascript
javascript删除数组元素并且数组长度减小的简单实例
2014/02/14 Javascript
Javascript中setTimeOut和setInterval的定时器用法
2015/06/12 Javascript
深入理解node exports和module.exports区别
2016/06/01 Javascript
jQuery简单实现仿京东分类导航层效果
2016/06/07 Javascript
bootstrap paginator分页插件的两种使用方式实例详解
2017/11/14 Javascript
详解从Vue-router到html5的pushState
2018/07/21 Javascript
Vue中android4.4不兼容问题的解决方法
2018/09/04 Javascript
vue2.0 + ele的循环表单及验证字段方法
2018/09/18 Javascript
基于vue+axios+lrz.js微信端图片压缩上传方法
2019/06/25 Javascript
nodejs环境使用Typeorm连接查询Oracle数据
2019/12/05 NodeJs
vue 数据遍历筛选 过滤 排序的应用操作
2020/11/17 Javascript
[04:44]显微镜下的DOTA2第二期——你所没有注意到的细节
2014/06/20 DOTA
[01:10:27]DOTA2-DPC中国联赛正赛 SAG vs XG BO3 第二场 3月5日
2021/03/11 DOTA
py2exe 编译ico图标的代码
2013/03/08 Python
在Python程序中进行文件读取和写入操作的教程
2015/04/28 Python
Python中安装easy_install的方法
2018/11/18 Python
使用Keras中的ImageDataGenerator进行批次读图方式
2020/06/17 Python
python实现批量命名照片
2020/06/18 Python
keras训练浅层卷积网络并保存和加载模型实例
2020/07/02 Python
python中的split、rsplit、splitlines用法说明
2020/10/23 Python
Python中lru_cache的使用和实现详解
2021/01/25 Python
利用CSS3的border-radius绘制太极及爱心图案示例
2016/05/17 HTML / CSS
巴西女装购物网站:Eclectic
2018/04/24 全球购物
意大利在线眼镜精品店:Ottica Lipari
2019/11/11 全球购物
餐饮部总监岗位职责范文
2014/02/13 职场文书
党支部承诺书范文
2014/03/28 职场文书
心理健康活动总结
2014/04/30 职场文书
2014年建筑工程工作总结
2014/12/03 职场文书
2015元旦晚会主持词(开场白+结束语)
2014/12/14 职场文书
Python Pandas读取Excel日期数据的异常处理方法
2022/02/28 Python
世界十大评分最高的动漫,CLANNAD上榜,第八赚足人们眼泪
2022/03/18 日漫
Hive常用日期格式转换语法
2022/06/25 数据库