【问题标题】:Get multi-class confusion matrix equal to number of class labels获得等于类标签数量的多类混淆矩阵
【发布时间】:2019-01-20 20:36:19
【问题描述】:

我在sklearn 中训练了随机森林分类器来预测多类分类问题。

我的数据集有四个类别标签。但是我的代码创建了 2x2 混淆矩阵

y_predict = rf.predict(X_test)
conf_mat = sklearn.metrics.confusion_matrix(y_test, y_predict)
print(conf_mat)

输出:

[[0,   0]

 [394, 39]]

如何获得 4x4 混淆矩阵来分析 TP、TN、FP、FN。

【问题讨论】:

  • 取决于您的y_test。您的y_test 是否包含所有 4 个标签?

标签: scikit-learn classification confusion-matrix


【解决方案1】:

来自
http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html 的文档

y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])

结果:

array([[2, 0, 0], 
       [0, 0, 1], 
       [1, 0, 2]])

【讨论】:

    猜你喜欢
    • 2019-03-01
    • 2021-11-11
    • 2014-10-19
    • 2019-05-22
    • 2020-01-22
    • 2020-10-24
    • 2018-11-06
    • 2021-07-02
    • 2018-10-21
    相关资源
    最近更新 更多