【问题标题】:How to correctly reshape the multiclass output of predict_proba of a sklearn classifier?如何正确重塑 sklearn 分类器的 predict_proba 的多类输出?
【发布时间】:2021-07-10 06:08:57
【问题描述】:

我有 10 个班级的多班级问题。 使用任何带有 predict_proba 的 sklearn 分类器,我得到的输出是

(n_classes, n_samples, n_classes_probability_1_or_0)

在我的情况下(10, 4789, 2)

现在我会做二进制分类

model.predict_proba(X)[:, 1]

我认为:

pred = np.array(model.predict_proba(X))
pred = pred.reshape(-1, 10, 2)[:, :, 1]

会做同样的事情,但订单完全关闭。

现在y[:, class] 对应于pred[class, :, 1]

我知道我想错了形状,但很遗憾我看不到。

如何正确重塑它? 目标是在 roc_auc_score 指标中使用它 我想要(instances, classes_probabilities = 1)的形状

你能帮忙吗? 提前谢谢!

【问题讨论】:

    标签: python numpy scikit-learn classification multilabel-classification


    【解决方案1】:

    如果您提到您正在使用MultiOutputClassifier,这将很有用,因为 scikit learn 中的大多数多类分类器不会返回类似您的东西,因此使用示例数据集:

    import numpy as np
    import pandas as pd
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.multioutput import MultiOutputClassifier
    from sklearn import preprocessing
    
    lb = preprocessing.LabelBinarizer()
    
    from sklearn.datasets import make_classification
    X, y = make_classification(n_samples=500,n_classes=10,n_informative=10,n_clusters_per_class=1)
    y = lb.fit_transform(y)
    

    设置分类器

    forest = RandomForestClassifier(n_estimators=10, random_state=1)
    model = MultiOutputClassifier(forest, n_jobs=-1)
    model.fit(X, y)
    

    你不需要考虑重塑它,只需提取值:

    pred = np.array(model.predict_proba(X))
    

    就像你之前所做的那样,这将对应于每一行是一个类,每一列是你的观察:

    pred[:,:,1].shape
    (10, 500)
    

    要得到你的概率,只需转置:

    prob1 = pred[:,:, 1].T
    
    prob1[:2]
    array([[0.9, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
           [0.1, 0. , 0.1, 0. , 0.7, 0. , 0.1, 0. , 0.1, 0. ]])
    

    如果我们实际提取它并堆叠比较:

    prob2 = np.hstack([i[:,1].reshape(-1,1) for i in model.predict_proba(X)])
    array([[0.9, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.1, 0. , 0.1, 0. , 0.7, 0. , 0.1, 0. , 0.1, 0. ]])
    

    【讨论】:

      猜你喜欢
      • 2018-08-14
      • 2017-11-04
      • 2014-01-02
      • 2020-04-11
      • 2015-07-19
      • 1970-01-01
      • 2018-03-03
      • 2018-09-11
      • 1970-01-01
      相关资源
      最近更新 更多