pytorch 多分类问题,计算百分比操作


Posted in Python onJuly 09, 2020

二分类或分类问题,网络输出为二维矩阵:批次x几分类,最大的为当前分类,标签为one-hot型的二维矩阵:批次x几分类

计算百分比有numpy和pytorch两种实现方案实现,都是根据索引计算百分比,以下为具体二分类实现过程。

pytorch

out = torch.Tensor([[0,3],
     [2,3],
     [1,0],
     [3,4]])
cond = torch.Tensor([[1,0],
      [0,1],
      [1,0],
      [1,0]])
 
persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
  [2, 3],
  [1, 0],
  [3, 4]]
cond = [[1, 0],
  [0, 1],
  [1, 0],
  [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)

补充知识:python 多分类画auc曲线和macro-average ROC curve

最近帮一个人做了一个多分类画auc曲线的东西,不过最后那个人不要了,还被说了一顿,心里很是不爽,anyway,我写代码的还是要继续写代码的,所以我准备把我修改的代码分享开来,供大家研究学习。处理的数据大改是这种xlsx文件:

IMAGE y_real y_predict 0其他 1豹纹 2弥漫 3斑片 4黄斑
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005111 (Copy).jpg 0 0 1 8.31E-19 7.59E-13 4.47E-15 2.46E-14
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM005201 (Copy).jpg 0 0 1 5.35E-17 4.38E-11 8.80E-13 3.85E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004938 (4) (Copy).jpg 0 0 1 1.20E-16 3.17E-11 6.26E-12 1.02E-11
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004349 (3) (Copy).jpg 0 0 1 5.66E-14 1.87E-09 6.50E-09 3.29E-09
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004673 (5) (Copy).jpg 0 0 1 5.51E-17 9.30E-12 1.33E-13 2.54E-12
/mnt/AI/HM/izy20200531c5/299/train/0其他/IM004450 (5) (Copy).jpg 0 0 1 4.81E-17 3.75E-12 3.96E-13 6.17E-13

导入基础的pandas和keras处理函数

import pandas as pd

from keras.utils import to_categorical

导入数据

data=pd.read_excel('5分类新.xlsx')

data.head()

导入机器学习库

from sklearn.metrics import precision_recall_curve
import numpy as np
from matplotlib import pyplot
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve, auc

把ground truth提取出来

true_y=data[' y_real'].to_numpy()

true_y=to_categorical(true_y)

把每个类别的数据提取出来

PM_y=data[[' 0其他',' 1豹纹',' 2弥漫',' 3斑片',' 4黄斑']].to_numpy()

PM_y.shape

计算每个类别的fpr和tpr

n_classes=PM_y.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
 fpr[i], tpr[i], _ = roc_curve(true_y[:, i], PM_y[:, i])
 roc_auc[i] = auc(fpr[i], tpr[i])

计算macro auc

from scipy import interp
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
 
# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
 mean_tpr += interp(all_fpr, fpr[i], tpr[i])
 
# Finally average it and compute AUC
mean_tpr /= n_classes
 
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

画图

import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.ticker import FuncFormatter
lw = 2
# Plot all ROC curves
plt.figure()
labels=['Category 0','Category 1','Category 2','Category 3','Category 4']
plt.plot(fpr["macro"], tpr["macro"],
   label='macro-average ROC curve (area = {0:0.4f})'
    ''.format(roc_auc["macro"]),
   color='navy', linestyle=':', linewidth=4)
 
colors = cycle(['aqua', 'darkorange', 'cornflowerblue','blue','yellow'])
for i, color in zip(range(n_classes), colors):
 plt.plot(fpr[i], tpr[i], color=color, lw=lw,
    label=labels[i]+'(area = {0:0.4f})'.format(roc_auc[i]))
 
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity (%)')
plt.ylabel('Sensitivity (%)')
plt.title('Some extension of Receiver operating characteristic to multi-class')
def to_percent(temp, position):
 return '%1.0f'%(100*temp)
plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
plt.gca().xaxis.set_major_formatter(FuncFormatter(to_percent))
plt.legend(loc="lower right")
plt.show()

展示

pytorch 多分类问题,计算百分比操作

上述的代码是在jupyter中运行的,所以是分开的

以上这篇pytorch 多分类问题,计算百分比操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Python的Tornado框架实现数据可视化的教程
May 02 Python
python实时分析日志的一个小脚本分享
May 07 Python
Python入门_学会创建并调用函数的方法
May 16 Python
python实现自动发送邮件发送多人、群发、多附件的示例
Jan 23 Python
详解python Todo清单实战
Nov 01 Python
python爬虫之urllib库常用方法用法总结大全
Nov 14 Python
python3.8 微信发送服务器监控报警消息代码实现
Nov 05 Python
python扫描线填充算法详解
Feb 19 Python
Python如何爬取qq音乐歌词到本地
Jun 01 Python
在Keras中利用np.random.shuffle()打乱数据集实例
Jun 15 Python
python pandas dataframe 去重函数的具体使用
Jul 20 Python
如何通过Python实现RabbitMQ延迟队列
Nov 28 Python
详解Python 循环嵌套
Jul 09 #Python
keras分类之二分类实例(Cat and dog)
Jul 09 #Python
python中tkinter窗口位置\坐标\大小等实现示例
Jul 09 #Python
Python2.x与3​​.x版本有哪些区别
Jul 09 #Python
浅谈keras中Dropout在预测过程中是否仍要起作用
Jul 09 #Python
在keras中对单一输入图像进行预测并返回预测结果操作
Jul 09 #Python
python求解汉诺塔游戏
Jul 09 #Python
You might like
第十一节 重载 [11]
2006/10/09 PHP
php使用ICQ网关发送手机短信
2013/10/30 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(八)
2014/06/23 PHP
php获取URL中带#号等特殊符号参数的解决方法
2014/09/02 PHP
Laravel执行migrate命令提示:No such file or directory的解决方法
2016/03/16 PHP
Laravel使用swoole实现websocket主动消息推送的方法介绍
2019/10/20 PHP
根据对象的某一属性进行排序的js代码(如:name,age)
2010/08/10 Javascript
Javascript核心读书有感之语言核心
2015/02/01 Javascript
JS使用ajax从xml文件动态获取数据显示的方法
2015/03/24 Javascript
jquery+html5制作超酷的圆盘时钟表
2015/04/14 Javascript
文字垂直滚动之javascript代码
2015/07/29 Javascript
Laravel中常见的错误与解决方法小结
2016/08/30 Javascript
移动端js触摸事件详解
2016/09/18 Javascript
Javascript中的神器——Promise
2017/02/08 Javascript
vue2.0实现倒计时的插件(时间戳 刷新 跳转 都不影响)
2017/03/30 Javascript
Vue 2.0的数据依赖实现原理代码简析
2017/07/10 Javascript
react实现一个优雅的图片占位模块组件详解
2017/10/30 Javascript
用jquery获取select标签中选中的option值及文本的示例
2018/01/25 jQuery
JavaScript折半查找(二分查找)算法原理与实现方法示例
2018/08/06 Javascript
Vue2.0中三种常用传值方式(父传子、子传父、非父子组件传值)
2018/08/16 Javascript
详解Vue CLI3 多页应用实践和源码设计
2018/08/30 Javascript
详解element-ui级联菜单(城市三级联动菜单)和回显问题
2019/10/02 Javascript
基于Vant UI框架实现时间段选择器
2020/12/24 Javascript
python ansible服务及剧本编写
2017/12/29 Python
Django之模型层多表操作的实现
2019/01/08 Python
pytorch实现Tensor变量之间的转换
2020/02/17 Python
python分别打包出32位和64位应用程序
2020/02/18 Python
Python使用xpath实现图片爬取
2020/09/16 Python
利用CSS3实现毛玻璃效果示例源码
2016/09/25 HTML / CSS
中国跨境电商:Tomtop
2017/03/16 全球购物
加拿大床上用品、家居装饰、厨房和浴室产品购物网站:Linen Chest
2018/06/05 全球购物
国际性能运动服装品牌:Dare 2b
2018/07/27 全球购物
列车乘务员工作不细心检讨书
2014/10/07 职场文书
z-index不起作用
2021/03/31 HTML / CSS
sql时间段切分实现每隔x分钟出一份高速门架车流量
2022/02/28 SQL Server
Spring Data JPA框架的核心概念和Repository接口
2022/04/28 Java/Android