pytorch分类模型绘制混淆矩阵以及可视化详解


Posted in Python onApril 07, 2022

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵
 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds)
 # 使用torch.no_grad()可以显著降低测试用例的GPU占用
    with torch.no_grad():
        for step, (imgs, targets) in enumerate(test_loader):
            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor
            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉
            targets = targets.squeeze()  # [50,1] ----->  [50]

            # 将变量转为gpu
            targets = targets.cuda()
            imgs = imgs.cuda()
            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())
            
            out = model(imgs)
            #记录混淆矩阵参数
            conf_matrix = confusion_matrix(out, targets, conf_matrix)
            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1
    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到np
corrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数
per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数

 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num))
 print(conf_matrix)

 # 获取每种Emotion的识别准确率
 print("每种情感总个数:",per_kinds)
 print("每种情感预测正确的个数:",corrects)
 print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵以及可视化详解

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵
Emotion=8#这个数值是具体的分类数,大家可以自行修改
labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签

# 显示数据
plt.imshow(conf_matrix, cmap=plt.cm.Blues)

# 在图中标注数量/概率信息
thresh = conf_matrix.max() / 2	#数值颜色阈值,如果数值超过这个,就颜色加深。
for x in range(Emotion_kinds):
    for y in range(Emotion_kinds):
        # 注意这里的matrix[y, x]不是matrix[x, y]
        info = int(conf_matrix[y, x])
        plt.text(x, y, info,
                 verticalalignment='center',
                 horizontalalignment='center',
                 color="white" if info > thresh else "black")
                 
plt.tight_layout()#保证图不重叠
plt.yticks(range(Emotion_kinds), labels)
plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°
plt.show()
plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵以及可视化详解

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵以及可视化详解

总结

到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python中精确输出JSON浮点数的方法
Apr 18 Python
python实现html转ubb代码(html2ubb)
Jul 03 Python
Python实现1-9数组形成的结果为100的所有运算式的示例
Nov 03 Python
pyqt5 键盘监听按下enter 就登陆的实例
Jun 25 Python
Python学习笔记之Zip和Enumerate用法实例分析
Aug 14 Python
在vscode中配置python环境过程解析
Sep 28 Python
基于python实现学生信息管理系统
Nov 22 Python
python绘制雪景图
Dec 16 Python
Python单元测试模块doctest的具体使用
Feb 10 Python
python实现简单颜色识别程序
Feb 19 Python
Django ORM filter() 的运用详解
May 14 Python
Python爬虫实现selenium处理iframe作用域问题
Jan 27 Python
Python OpenCV之常用滤波器使用详解
python Tkinter模块使用方法详解
一行Python命令实现批量加水印
Apr 07 #Python
Python中Matplotlib的点、线形状、颜色以及绘制散点图
详解Python中*args和**kwargs的使用
Apr 07 #Python
Python列表的索引与切片
Apr 07 #Python
Python字符串的转义字符
You might like
模板引擎Smarty深入浅出介绍
2006/12/06 PHP
推荐一篇入门级的Class文章
2007/03/19 PHP
了解Joomla 这款来自国外的php网站管理系统
2010/03/11 PHP
php mongodb操作类 带几个简单的例子
2016/08/25 PHP
ThinkPHP5.0框架使用build 自动生成模块操作示例
2019/04/11 PHP
用dtree实现树形菜单 dtree使用说明
2011/10/17 Javascript
Jquery提交表单 Form.js官方插件介绍
2012/03/01 Javascript
js iframe跨域访问(同主域/非同主域)分别深入介绍
2013/01/24 Javascript
js/jquery去掉空格,回车,换行示例代码
2013/11/05 Javascript
jQuery实用函数用法总结
2014/08/29 Javascript
js操作滚动条事件实例
2015/01/29 Javascript
EasyUi datagrid 实现表格分页
2015/02/10 Javascript
基于JavaScript实现窗口拖动效果
2017/01/18 Javascript
js绑定事件和解绑事件
2017/04/27 Javascript
angular或者js怎么确定选中ul中的哪几个li
2017/08/16 Javascript
通过Python来使用七牛云存储的方法详解
2015/08/07 Python
详解Python的Django框架中manage命令的使用与扩展
2016/04/11 Python
Python基于identicon库创建类似Github上用的头像功能
2017/09/25 Python
python文件名和文件路径操作实例
2017/09/29 Python
学习python中matplotlib绘图设置坐标轴刻度、文本
2018/02/07 Python
TensorFlow利用saver保存和提取参数的实例
2018/07/26 Python
Selenium元素的常用操作方法分析
2018/08/10 Python
python写日志文件操作类与应用示例
2019/07/01 Python
Python发送邮件的实例代码讲解
2019/10/16 Python
Python3连接Mysql8.0遇到的问题及处理步骤
2020/02/17 Python
python向xls写入数据(包括合并,边框,对齐,列宽)
2021/02/02 Python
webapp字号大小跟随系统字号大小缩放的示例代码
2018/12/26 HTML / CSS
美国汽车轮胎和轮毂销售网站:Tire Rack
2018/01/11 全球购物
Groupon西班牙官方网站:在线优惠券和交易,节省高达70%
2021/03/13 全球购物
简单说说tomcat的配置
2013/05/28 面试题
网上蛋糕店创业计划书
2014/01/24 职场文书
《月亮湾》教学反思
2014/04/14 职场文书
岗位聘任协议书
2015/09/21 职场文书
2016学雷锋优秀志愿者事迹材料
2016/02/25 职场文书
Python获取江苏疫情实时数据及爬虫分析
2021/08/02 Python
js 实现Material UI点击涟漪效果示例
2022/09/23 Javascript