pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python使用win32com在百度空间插入html元素示例
Feb 20 Python
深入理解Python中装饰器的用法
Jun 28 Python
python+opencv实现的简单人脸识别代码示例
Nov 14 Python
Java分治归并排序算法实例详解
Dec 12 Python
Python实现连接两个无规则列表后删除重复元素并升序排序的方法
Feb 05 Python
python中dir()与__dict__属性的区别浅析
Dec 10 Python
对python使用telnet实现弱密码登录的方法详解
Jan 26 Python
Python 用三行代码提取PDF表格数据
Oct 13 Python
python计算导数并绘图的实例
Feb 29 Python
matplotlib 三维图表绘制方法简介
Sep 20 Python
Python3+PyCharm+Django+Django REST framework配置与简单开发教程
Feb 16 Python
OpenCV 图像梯度的实现方法
Jul 25 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
PHP如何得到当前页和上一页的地址?
2006/11/27 PHP
PHP 在5.1.* 和5.2.*之间 PDO数据库操作中的不同之处小结
2012/03/07 PHP
深入for,while,foreach遍历时间比较的详解
2013/06/08 PHP
ThinkPHP中html:list标签用法分析
2016/01/09 PHP
PHP7实现和CryptoJS的AES加密方式互通示例【AES-128-ECB加密】
2019/06/08 PHP
jQuery 性能优化指南(3)
2009/05/21 Javascript
javascript+css 网页每次加载不同样式的实现方法
2009/12/27 Javascript
JS获取文本框,下拉框,单选框的值的简单实例
2014/02/26 Javascript
一个JavaScript的求爱小特效
2014/05/09 Javascript
node.js中的path.delimiter方法使用说明
2014/12/09 Javascript
在JS中操作时间之getUTCMilliseconds()方法的使用
2015/06/10 Javascript
jquery实现简易的移动端验证表单
2015/11/08 Javascript
javascript实现省市区三级联动下拉框菜单
2015/11/17 Javascript
jquery密码强度校验
2015/12/02 Javascript
JS原型、原型链深入理解
2016/02/27 Javascript
WordPress 单页面上一页下一页的实现方法【附代码】
2016/03/10 Javascript
H5实现中奖记录逐行滚动切换效果
2017/03/13 Javascript
Javascript前端经典的面试题及答案
2017/03/14 Javascript
p5.js入门教程之鼠标交互的示例
2018/03/16 Javascript
vue cli 3.0 使用全过程解析
2018/06/14 Javascript
使用Vue-cli 中为单独页面设置背景图片铺满全屏
2020/07/17 Javascript
PyQt5实现无边框窗口的标题拖动和窗口缩放
2018/04/19 Python
Python登录注册验证功能实现
2018/06/18 Python
Python学习笔记之For循环用法详解
2019/08/14 Python
使用python 的matplotlib 画轨道实例
2020/01/19 Python
python读取xml文件方法解析
2020/08/04 Python
丝芙兰新加坡官网:Sephora新加坡
2018/12/04 全球购物
C#里面如何判断一个Object是否是某种类型(如Boolean)?
2016/02/10 面试题
Servlet如何得到服务器的信息
2015/12/22 面试题
2014教师研修学习体会
2014/07/08 职场文书
助人为乐好少年事迹材料
2014/08/18 职场文书
公司门卫岗位职责
2015/04/13 职场文书
2015年学生会部门工作总结
2015/04/21 职场文书
2016年幼儿园教师师德承诺书
2016/03/25 职场文书
导游词之上海东方明珠塔
2019/09/25 职场文书
Java Spring 控制反转(IOC)容器详解
2021/10/05 Java/Android