【发布时间】:2017-01-08 10:31:12
【问题描述】:
我需要使用 3 折交叉验证来训练 Random Forest classifier。对于每个样本,当它恰好在测试集中时,我需要检索预测概率。
我正在使用 scikit-learn 版本 0.18.dev0。
此新版本添加了使用方法cross_val_predict() 和附加参数method 来定义估计器需要哪种预测的功能。
在我的例子中,我想使用predict_proba() 方法,它在多类场景中返回每个类的概率。
但是,当我运行该方法时,我会得到预测概率矩阵,其中每一行代表一个样本,每一列代表特定类的预测概率。
问题是方法没有指明每一列对应哪个类。
我需要的值是相同的(在我的情况下使用RandomForestClassifier)在属性classes_中返回定义为:
classes_ : 形状数组 = [n_classes] 或此类数组的列表 类标签(单输出问题),或类标签数组列表(多输出问题)。
predict_proba() 需要它,因为在其文档中写道:
类的顺序与属性classes_中的顺序相对应。
一个最小的例子如下:
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict
clf = RandomForestClassifier()
X = np.random.randn(10, 10)
y = y = np.array([1] * 4 + [0] * 3 + [2] * 3)
# how to get classes from here?
proba = cross_val_predict(estimator=clf, X=X, y=y, method="predict_proba")
# using the classifier without cross-validation
# it is possible to get the classes in this way:
clf.fit(X, y)
proba = clf.predict_proba(X)
classes = clf.classes_
【问题讨论】:
-
对于您无法访问
y的情况? -
@juanpa.arrivillaga 我可以访问
y,但我不知道标签的排序顺序。我可能推测它们是按升序排列的,但我不完全确定。 -
抱歉,我还没有 0.18,所以我无法对此进行测试,这可能很明显,但
clf对象现在不包含classes_属性吗?跨度> -
Np。不,我对其进行了测试,但它没有
classes_属性。也许分类器被复制到cross_val_predict()方法中,而原始的没有更新。 -
你是对的!它does.
标签: python scikit-learn cross-validation