pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python获取Windows或Linux主机名称通用函数分享
Nov 22 Python
Python异常学习笔记
Feb 03 Python
python多线程方式执行多个bat代码
Jun 07 Python
Python3网络爬虫之使用User Agent和代理IP隐藏身份
Nov 23 Python
Python中字典的浅拷贝与深拷贝用法实例分析
Jan 02 Python
python之super的使用小结
Aug 13 Python
Python 私有化操作实例分析
Nov 21 Python
pytorch中torch.max和Tensor.view函数用法详解
Jan 03 Python
解决python 虚拟环境删除包无法加载的问题
Jul 13 Python
python 获取域名到期时间的方法步骤
Feb 10 Python
详解python字符串驻留技术
May 21 Python
Python 中面向接口编程
May 20 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
关于二级目录拖拽排序的实现(源码示例下载)
2013/04/26 PHP
PHP fopen()和 file_get_contents()应用与差异介绍
2014/03/19 PHP
PHP的MVC模式实现原理分析(一相简单的MVC框架范例)
2014/04/29 PHP
Yii控制器中操作视图js的方法
2016/07/04 PHP
PHP依赖注入原理与用法分析
2018/08/21 PHP
Laravel框架处理用户的请求操作详解
2019/12/20 PHP
jQuery实现可拖动的浮动层完整代码
2013/05/27 Javascript
网站如何做到完全不需要jQuery也可以满足简单需求
2013/06/27 Javascript
js实现select跳转功能代码
2014/10/22 Javascript
正则表达式优化JSON字符串的技巧
2015/12/24 Javascript
Bootstrap每天必学之折叠
2016/04/12 Javascript
AngularJS控制器继承自另一控制器
2016/05/09 Javascript
漂亮! js实现颜色渐变效果
2016/08/12 Javascript
JS函数多个参数默认值指定方法分析
2016/11/28 Javascript
Jquery Easyui验证组件ValidateBox使用详解(20)
2016/12/18 Javascript
AngularJS实现进度条功能示例
2017/07/05 Javascript
深入理解Vue2.x的虚拟DOM diff原理
2017/09/27 Javascript
JavaScript捕捉事件和阻止冒泡事件实例分析
2018/08/03 Javascript
React实现全选功能
2020/08/25 Javascript
JQuery基于FormData异步提交数据文件
2020/09/01 jQuery
Vue 数据绑定的原理分析
2020/11/16 Javascript
python+Django+apache的配置方法详解
2016/06/01 Python
解决pandas中读取中文名称的csv文件报错的问题
2018/07/04 Python
利用Python进行数据可视化常见的9种方法!超实用!
2018/07/11 Python
使用Python的toolz库开始函数式编程的方法
2018/11/15 Python
利用python将图片版PDF转文字版PDF
2019/05/03 Python
使用Python和Scribus创建一个RGB立方体的方法
2019/07/17 Python
python3 requests库实现多图片爬取教程
2019/12/18 Python
python进行OpenCV实战之画图(直线、矩形、圆形)
2020/08/27 Python
Python Process创建进程的2种方法详解
2021/01/25 Python
基于ccs3的timeline时间线实现方法
2020/04/30 HTML / CSS
TripAdvisor西班牙官方网站:全球领先的旅游网站
2018/01/10 全球购物
什么是Web Service?
2012/07/25 面试题
计算机专业应届毕业生自荐信
2013/09/26 职场文书
岗位职责范本
2013/11/23 职场文书
拾金不昧感谢信范文
2015/01/21 职场文书