【问题标题】:Plot confusion matrix sklearn with multiple labels绘制具有多个标签的混淆矩阵 sklearn
【发布时间】:2016-08-19 07:56:40
【问题描述】:

我正在为多标签数据绘制混淆矩阵,其中标签如下所示:

标签1:1、0、0、0

标签2:0、1、0、0

标签3:0、0、1、0

标签4:0、0、0、1

我能够使用以下代码成功分类。 我只需要一些帮助来绘制混淆矩阵。

    for i in range(4):
        y_train= y[:,i]
        print('Train subject %d, class %s' % (subject, cols[i]))
        lr.fit(X_train[::sample,:],y_train[::sample])
        pred[:,i] = lr.predict_proba(X_test)[:,1]

我使用以下代码打印混淆矩阵,但它总是返回一个 2X2 矩阵

prediction = lr.predict(X_train)

print(confusion_matrix(y_train, prediction))

【问题讨论】:

  • 我认为 OP 意味着多类而不是多标签。

标签: python machine-learning scikit-learn confusion-matrix


【解决方案1】:

我找到了一个可以绘制从sklearn 生成的混淆矩阵的函数。

import numpy as np


def plot_confusion_matrix(cm,
                          target_names,
                          title='Confusion matrix',
                          cmap=None,
                          normalize=True):
    """
    given a sklearn confusion matrix (cm), make a nice plot

    Arguments
    ---------
    cm:           confusion matrix from sklearn.metrics.confusion_matrix

    target_names: given classification classes such as [0, 1, 2]
                  the class names, for example: ['high', 'medium', 'low']

    title:        the text to display at the top of the matrix

    cmap:         the gradient of the values displayed from matplotlib.pyplot.cm
                  see http://matplotlib.org/examples/color/colormaps_reference.html
                  plt.get_cmap('jet') or plt.cm.Blues

    normalize:    If False, plot the raw numbers
                  If True, plot the proportions

    Usage
    -----
    plot_confusion_matrix(cm           = cm,                  # confusion matrix created by
                                                              # sklearn.metrics.confusion_matrix
                          normalize    = True,                # show proportions
                          target_names = y_labels_vals,       # list of names of the classes
                          title        = best_estimator_name) # title of graph

    Citiation
    ---------
    http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

    """
    import matplotlib.pyplot as plt
    import numpy as np
    import itertools

    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]


    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        else:
            plt.text(j, i, "{:,}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")


    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    plt.show()

看起来像这样

【讨论】:

    【解决方案2】:

    这对我来说效果最好:

    from sklearn.metrics import multilabel_confusion_matrix
    y_unique = y_test.unique()
    mcm = multilabel_confusion_matrix(y_test, y_pred, labels = y_unique)
    mcm
    

    【讨论】:

      【解决方案3】:

      我发现这在sklearn 的存储库中仍然是一个未解决的问题:

      https://github.com/scikit-learn/scikit-learn/issues/3452

      但是已经有一些尝试实施它。来自同一个 #3452 线程问题:

      https://github.com/Magellanea/scikit-learn/commit/514287c1d5dad2f0ab4918dc4da5cf7053fe6734#diff-b04acd877dd793f28ae7be13a999ed88R187

      您可以检查函数中提出的代码,看看是否符合您的需求。

      【讨论】:

      • 我用 multilabel_confusion_matrix 替换了confusion_matrix,它给出了一个错误,即未定义名称'multilabel_confusion_matrix'。这个问题有解决方法吗?这个问题似乎在 Github 上公开。
      • 正如我所说:“仍然是一个悬而未决的问题”。我只是给了你代码的链接,以防你想尝试使用它。但它不在 sklearn 的代码中,这就是为什么它说它没有定义。如果您想使用(我没有尝试过),您应该在您自己的代码中包含multilabel_confusion_matrix 中的所有代码并调用该函数。要小心,因为这是 2014 年以来的一个悬而未决的问题,而且它仍然是一个悬而未决的问题这一事实可能表明它不是一个微不足道的问题。我只是给了你一个指针,以防你想自己尝试并自己解决。祝你好运!
      【解决方案4】:
      from sklearn.metrics import multilabel_confusion_matrix
      
      mul_c = multilabel_confusion_matrix(
          test_Y,
          pred_k,
          labels=["benign", "dos","probe","r2l","u2r"])
      mul_c
      

      【讨论】:

      • 虽然此代码可能会为问题提供解决方案,但最好添加有关其工作原理/方式的上下文。这可以帮助未来的用户学习并将这些知识应用到他们自己的代码中。解释代码时,您也可能会以赞成票的形式从用户那里获得积极的反馈。
      【解决方案5】:

      我找到了一个使用 sklearn 和 seaborn 库的简单解决方案。

      from sklearn.metrics import confusion_matrix, classification_report
      from matplotlib import pyplot as plt
      import seaborn as sns
      
      def plot_confusion_matrix(y_test,y_scores, classNames):
          y_test=np.argmax(y_test, axis=1)
          y_scores=np.argmax(y_scores, axis=1)
          classes = len(classNames)
          cm = confusion_matrix(y_test, y_scores)
          print("**** Confusion Matrix ****")
          print(cm)
          print("**** Classification Report ****")
          print(classification_report(y_test, y_scores, target_names=classNames))
          con = np.zeros((classes,classes))
          for x in range(classes):
              for y in range(classes):
                  con[x,y] = cm[x,y]/np.sum(cm[x,:])
      
          plt.figure(figsize=(40,40))
          sns.set(font_scale=3.0) # for label size
          df = sns.heatmap(con, annot=True,fmt='.2', cmap='Blues',xticklabels= classNames , yticklabels= classNames)
          df.figure.savefig("image2.png")
      
      classNames = ['A', 'B', 'C', 'D', 'E'] 
      plot_confusion_matrix(y_test,y_scores, classNames) 
      #y_test is your ground truth
      #y_scores is your predicted probabilities
      

      【讨论】:

        猜你喜欢
        • 2013-10-14
        • 2019-05-22
        • 2020-10-24
        • 2020-01-22
        • 2021-03-13
        • 2021-01-17
        • 2018-11-06
        • 2016-01-31
        相关资源
        最近更新 更多