可视化混淆矩阵

混淆矩阵是一个表, 我们用它来理解分类模型的性能. 这有助于我们了解我们如何将测试数据分类到不同的类. 当我们想要微调我们的算法时, 我们需要了解在进行这些更改之前数据如何被错误分类. 一些类比其他类更糟, 混淆矩阵将帮助我们理解这一点. 让我们来看下图:

在上面的图表中, 我们可以看到如何将数据分类到不同的分类中. 理想情况下, 我们希望所有的非对角元素都为0. 这将表示完美的分类! 让我们考虑分类0.总的来说, 52项实际上属于类0.如果我们总结第一行中的数字, 我们得到52. 现在, 这些项目中的45个被正确地预测, 但是我们的分类器说, 其中四个属于类别1, 其中三个属于类别2. 我们可以对剩余的两行应用相同的分析. 一个有趣的事情要注意的是, 来自类1的11个项目被错误分类为类0.这构成了这个类中的大约16%的数据点. 这是一个insight, 我们可以用来优化我们的模型.

怎么做...?

  • 我们将使用我们已经提供给您的confusion_matrix.py文件作为参考. 让我们看看如何从我们的数据中提取混淆矩阵:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

def plot_confusion_matrix(confusion_mat):
    plt.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.gray)
    plt.title('Confusion matrix')
    plt.colorbar()
    tick_marks = np.arange(4)
    plt.xticks(tick_marks, tick_marks)
    plt.yticks(tick_marks, tick_marks)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()


y_true = [1, 0, 0, 2, 1, 0, 3, 3, 3]
y_pred = [1, 1, 0, 2, 1, 0, 1, 3, 3]
confusion_mat = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(confusion_mat)

我们在这里使用一些示例数据. 我们有四个类, 值的范围从0到3.我们也预测标签. 我们使用confusion_matrix方法提取混淆矩阵并绘制它.

  • 如下图:对角线的颜色很strong, 我们希望它们更strong. 黑色表示零. 在非对角空格中有一对灰色的颜色, 表示错误分类. 例如, 当实际标签为0时, 预测的标签为1, 我们可以在第一行中看到. 事实上, 在第二列包含非零的三行的意义上, 所有错误分类属于类-1. 从图中很容易看到这一点.

results matching ""

    No results matching ""