利用python中的matplotlib打印混淆矩阵实例


Posted in Python onJune 16, 2020

前面说过混淆矩阵是我们在处理分类问题时,很重要的指标,那么如何更好的把混淆矩阵给打印出来呢,直接做表或者是前端可视化,小编曾经就尝试过用前端(D5)做出来,然后截图,显得不那么好看。。

代码:

import itertools
import matplotlib.pyplot as plt
import numpy as np
 
def plot_confusion_matrix(cm, classes,
       normalize=False,
       title='Confusion matrix',
       cmap=plt.cm.Blues):
 """
 This function prints and plots the confusion matrix.
 Normalization can be applied by setting `normalize=True`.
 """
 if normalize:
  cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  print("Normalized confusion matrix")
 else:
  print('Confusion matrix, without normalization')
 
 print(cm)
 
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 
 fmt = '.2f' if normalize else 'd'
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  plt.text(j, i, format(cm[i, j], fmt),
     horizontalalignment="center",
     color="white" if cm[i, j] > thresh else "black")
 
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()
 # plt.savefig('confusion_matrix',dpi=200)
 
cnf_matrix = np.array([
 [4101, 2, 5, 24, 0],
 [50, 3930, 6, 14, 5],
 [29, 3, 3973, 4, 0],
 [45, 7, 1, 3878, 119],
 [31, 1, 8, 28, 3936],
])
 
class_names = ['Buildings', 'Farmland', 'Greenbelt', 'Wasteland', 'Water']
 
# plt.figure()
# plot_confusion_matrix(cnf_matrix, classes=class_names,
#      title='Confusion matrix, without normalization')
 
# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
      title='Normalized confusion matrix')

在放矩阵位置,放一下你的混淆矩阵就可以,当然可视化混淆矩阵这一步也可以直接在模型运行中完成。

补充知识:混淆矩阵(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)

原理

在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能. 混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量. 这个矩阵的每一行表示真实类中的实例, 而每一列表示预测类中的实例 (Tensorflow 和 scikit-learn 采用的实现方式). 也可以是, 每一行表示预测类中的实例, 而每一列表示真实类中的实例 (Confusion matrix From Wikipedia 中的定义). 通过混淆矩阵, 可以很容易看出系统是否会弄混两个类, 这也是混淆矩阵名字的由来.

混淆矩阵是一种特殊类型的列联表(contingency table)或交叉制表(cross tabulation or crosstab). 其有两维 (真实值 "actual" 和 预测值 "predicted" ), 这两维都具有相同的类("classes")的集合. 在列联表中, 每个维度和类的组合是一个变量. 列联表以表的形式, 可视化地表示多个变量的频率分布.

使用混淆矩阵( scikit-learn 和 Tensorflow)

下面先介绍在 scikit-learn 和 tensorflow 中计算混淆矩阵的 API (Application Programming Interface) 接口函数, 然后在一个示例中, 使用这两个 API 函数.

scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口

skearn.metrics.confusion_matrix(
 y_true, # array, Gound true (correct) target values
 y_pred, # array, Estimated targets as returned by a classifier
 labels=None, # array, List of labels to index the matrix.
 sample_weight=None # array-like of shape = [n_samples], Optional sample weights
)

在 scikit-learn 中, 计算混淆矩阵用来评估分类的准确度.

按照定义, 混淆矩阵 C 中的元素 Ci,j 等于真实值为组 i , 而预测为组 j 的观测数(the number of observations). 所以对于二分类任务, 预测结果中, 正确的负例数(true negatives, TN)为 C0,0; 错误的负例数(false negatives, FN)为 C1,0; 真实的正例数为 C1,1; 错误的正例数为 C0,1.

如果 labels 为 None, scikit-learn 会把在出现在 y_true 或 y_pred 中的所有值添加到标记列表 labels 中, 并排好序.

Tensorflow 混淆矩阵函数 tf.confusion_matrix API 接口

tf.confusion_matrix(
 labels, # 1-D Tensor of real labels for the classification task
 predictions, # 1-D Tensor of predictions for a givenclassification
 num_classes=None, # The possible number of labels the classification task can have
 dtype=tf.int32, # Data type of the confusion matrix 
 name=None, # Scope name
 weights=None, # An optional Tensor whose shape matches predictions
)

Tensorflow tf.confusion_matrix 中的 num_classes 参数的含义, 与 scikit-learn sklearn.metrics.confusion_matrix 中的 labels 参数相近, 是与标记有关的参数, 表示类的总个数, 但没有列出具体的标记值. 在 Tensorflow 中一般是以整数作为标记, 如果标记为字符串等非整数类型, 则需先转为整数表示. 如果 num_classes 参数为 None, 则把 labels 和 predictions 中的最大值 + 1, 作为num_classes 参数值.

tf.confusion_matrix 的 weights 参数和 sklearn.metrics.confusion_matrix 的 sample_weight 参数的含义相同, 都是对预测值进行加权, 在此基础上, 计算混淆矩阵单元的值.

使用示例

#!/usr/bin/env python
# -*- coding: utf8 -*-
"""
Author: klchang
Description: 
A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix.
Date: 2018.9.8
"""
from __future__ import print_function
import tensorflow as tf
import sklearn.metrics
 
y_true = [1, 2, 4]
y_pred = [2, 2, 4]
 
# Build graph with tf.confusion_matrix operation
sess = tf.InteractiveSession()
op = tf.confusion_matrix(y_true, y_pred)
op2 = tf.confusion_matrix(y_true, y_pred, num_classes=6, dtype=tf.float32, weights=tf.constant([0.3, 0.4, 0.3]))
# Execute the graph
print ("confusion matrix in tensorflow: ")
print ("1. default: \n", op.eval())
print ("2. customed: \n", sess.run(op2))
sess.close()
 
# Use sklearn.metrics.confusion_matrix function
print ("\nconfusion matrix in scikit-learn: ")
print ("1. default: \n", sklearn.metrics.confusion_matrix(y_true, y_pred))
print ("2. customed: \n", sklearn.metrics.confusion_matrix(y_true, y_pred, labels=range(6), sample_weight=[0.3, 0.4, 0.3]))

以上这篇利用python中的matplotlib打印混淆矩阵实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python开发的小球完全弹性碰撞游戏代码
Oct 15 Python
Python实现栈的方法
May 26 Python
浅谈python jieba分词模块的基本用法
Nov 09 Python
python 日期排序的实例代码
Jul 11 Python
python查找重复图片并删除(图片去重)
Jul 16 Python
对Python中一维向量和一维向量转置相乘的方法详解
Aug 26 Python
Python列表原理与用法详解【创建、元素增加、删除、访问、计数、切片、遍历等】
Oct 30 Python
关于多元线性回归分析——Python&SPSS
Feb 24 Python
pycharm中如何自定义设置通过“ctrl+滚轮”进行放大和缩小实现方法
Sep 16 Python
Python importlib模块重载使用方法详解
Oct 13 Python
python 操作excel表格的方法
Dec 05 Python
解决pycharm下载库时出现Failed to install package的问题
Sep 04 Python
Python SMTP配置参数并发送邮件
Jun 16 #Python
基于matplotlib中ion()和ioff()的使用详解
Jun 16 #Python
Python数据相关系数矩阵和热力图轻松实现教程
Jun 16 #Python
matplotlib.pyplot.matshow 矩阵可视化实例
Jun 16 #Python
使用python matploblib库绘制准确率,损失率折线图
Jun 16 #Python
为什么称python为胶水语言
Jun 16 #Python
在Keras中利用np.random.shuffle()打乱数据集实例
Jun 15 #Python
You might like
不错的一篇面向对象的PHP开发模式(简写版)
2007/03/15 PHP
PHP 字符串加密函数(在指定时间内加密还原字符串,超时无法还原)
2010/04/28 PHP
大家都应该掌握的PHP关联数组使用技巧
2015/12/25 PHP
php文件后缀不强制为.php的实操方法
2019/09/18 PHP
js获取变量
2006/08/24 Javascript
JavaScript在IE和Firefox浏览器下的7个差异兼容写法小结
2010/06/18 Javascript
基于jquery的内容循环滚动小模块(仿新浪微博未登录首页滚动微博显示)
2011/03/28 Javascript
左侧是表头的JS表格控件(自写,网上没有的)
2013/06/04 Javascript
javascript使用onclick事件改变选中行的颜色
2013/12/30 Javascript
js实现带圆角的多级下拉菜单效果
2015/08/28 Javascript
动态的9*9乘法表效果的实现代码
2016/05/16 Javascript
JSON格式的时间/Date(2367828670431)/格式转为正常的年-月-日 格式的代码
2016/07/27 Javascript
jQuery插件EasyUI实现Layout框架页面中弹出窗体到最顶层效果(穿越iframe)
2016/08/05 Javascript
Jquery针对tr td的一些实用操作方法(必看篇)
2016/10/05 Javascript
JavaScript用二分法查找数据的实例代码
2017/06/17 Javascript
jQuery实现的form转json经典示例
2017/10/10 jQuery
AngularJS中的作用域实例分析
2018/05/16 Javascript
微信小程序实现签到功能
2018/10/31 Javascript
详解javascript void(0)
2020/07/13 Javascript
python中定义结构体的方法
2013/03/04 Python
利用Python获取赶集网招聘信息前篇
2016/04/18 Python
python字符串常用方法
2018/06/14 Python
基于python3实现socket文件传输和校验
2018/07/28 Python
python自动化测试之如何解析excel文件
2019/06/27 Python
Python列表删除元素del、pop()和remove()的区别小结
2019/09/11 Python
Django实现从数据库中获取到的数据转换为dict
2020/03/27 Python
英国最大的在线时尚眼镜店:Eyewearbrands
2019/03/12 全球购物
英国家电购物网站:Sonic Direct
2019/03/26 全球购物
高中生期末评语大全
2014/01/28 职场文书
行政工作个人的自我评价
2014/02/13 职场文书
药品业务员岗位职责
2014/04/17 职场文书
感恩小明星事迹材料
2014/05/23 职场文书
2015年发展党员工作总结报告
2015/03/31 职场文书
土木工程生产实习心得体会
2016/01/22 职场文书
2019大学竞选班长发言稿
2019/06/27 职场文书
mysql5.5中文乱码问题解决的有用方法
2022/05/30 MySQL