【问题标题】:Multilabel-indicator is not supported for confusion matrix混淆矩阵不支持多标签指示符
【发布时间】:2018-04-07 19:34:00
【问题描述】:

multilabel-indicator is not supported 是我在尝试运行时收到的错误消息:

confusion_matrix(y_test, predictions)

y_test 是一个DataFrame,它的形状:

Horse | Dog | Cat
1       0     0
0       1     0
0       1     0
...     ...   ...

predictionsnumpy array

[[1, 0, 0],
 [0, 1, 0],
 [0, 1, 0]]

我已经搜索了一些错误消息,但还没有真正找到我可以应用的东西。有什么提示吗?

【问题讨论】:

  • 只是想为正在寻找正确方法来可视化多标签分类器错误的任何人加两分钱:您的预测数组看起来像来自 multiclass 分类器。混淆矩阵不适合同时预测多个标签的 multilabel 分类。

标签: python numpy scikit-learn classification


【解决方案1】:

不,您对confusion_matrix 的输入必须是预测列表,而不是 OHE(一种热编码)。在您的y_testy_pred 上拨打argmax,您应该会得到您所期望的。

confusion_matrix(
    y_test.values.argmax(axis=1), predictions.argmax(axis=1))

array([[1, 0],
       [0, 2]])

【讨论】:

    【解决方案2】:

    混淆矩阵采用标签向量(不是 one-hot 编码)。你应该跑

    confusion_matrix(y_test.values.argmax(axis=1), predictions.argmax(axis=1))
    

    【讨论】:

      【解决方案3】:
      from sklearn.metrics import confusion_matrix
      
      predictions_one_hot = model.predict(test_data)
      cm = confusion_matrix(labels_one_hot.argmax(axis=1), predictions_one_hot.argmax(axis=1))
      print(cm)
      

      输出会是这样的:

      [[298   2  47  15  77   3  49]
       [ 14  31   2   0   5   1   2]
       [ 64   5 262  22  94  38  43]
       [ 16   1  20 779  15  14  34]
       [ 49   0  71  33 316   7 118]
       [ 14   0  42  23   5 323   9]
       [ 20   1  27  32  97  13 436]]
      

      【讨论】:

        【解决方案4】:

        如果你有 numpy.ndarray 你可以试试下面的

        
        import seaborn as sns
        
        T5_lables = ['4TCM','WCM','WSCCM','IWCM','CCM']    
        
        ax= plt.subplot()
        
        cm = confusion_matrix(np.asarray(Y_Test).argmax(axis=1), np.asarray(Y_Pred).argmax(axis=1))
        sns.heatmap(cm, annot=True, fmt='g', ax=ax);  #annot=True to annotate cells, ftm='g' to disable scientific notation
        
        # labels, title and ticks
        ax.set_xlabel('Predicted labels');ax.set_ylabel('True labels'); 
        ax.set_title('Confusion Matrix'); 
        ax.xaxis.set_ticklabels(T5_lables); ax.yaxis.set_ticklabels(T5_lables);
        
        
        

        【讨论】:

          猜你喜欢
          • 2019-02-12
          • 2020-01-22
          • 1970-01-01
          • 2018-11-06
          • 2021-11-11
          • 2020-07-31
          • 2021-08-15
          • 2020-08-30
          • 2019-05-22
          相关资源
          最近更新 更多