【问题标题】:Cross Validation sklearn - How the splits are performed?交叉验证 sklearn - 如何执行拆分?
【发布时间】:2016-04-13 10:11:52
【问题描述】:

我目前正在处理一个分类问题,并且对sklearn/scikit-learn Python 模块的交叉验证功能有疑问。考虑以下调用:

cv_scores = cross_validation.cross_val_score(rfc, X, y, cv=self.cv_folds, n_jobs = 1)

self.cv_folds 是一个数字,例如5. 该函数实际上返回一个数组,其中包含交叉验证的每一折的分数。但现在我需要知道 该函数是如何执行拆分的。这意味着,哪些记录已分配给每次迭代的测试和训练集。为了更清楚,让我们考虑一个小的示例数据集:

(1)  0 1:9151.57142857 2:158.0 3:0.0136674259681 4:5.0 5:438.0 6:6.0  7:9.25388888889
(2)  1 1:3884.8 2:20338.0 3:0.0280373831776 4:194.0 5:320.0 6:9.0 7:42.8808333333
(3)  0 1:5219.5 2:241.0 3:0.00171821305842 4:55.0 5:1745.0 6:3.0 7:42.8808333333
(4)  0 1:1386.0 2:2125.0 3:0.0161290322581 4:315.0 5:309.0 6:5.0 7:14.8722222222
(5)  1 1:5508.375 2:27.0 3:0.00302245250432 4:1231.0 5:2315.0 6:7.0 7:591.213611111
(6)  1 1:12488.0 2:404.0 3:0.020942408377 4:31.0 5:190.0 6:4.0 7:9.25388888889
(7)  1 1:1748.4 2:0.0 3:0.00293685756241 4:376.0 5:1361.0 6:4.0 7:96.5372222222
(8)  1 1:3401.25 2:476.0 3:0.0714285714286 4:16.0 5:41.0 6:3.0 7:3.19722222222
(9)  1 1:2748.4 2:614.0 3:0.25 4:3.0 5:15.0 6:4.0 7:3.19722222222    
(10) 1 1:1386.0 2:2125.0 3:0.0161290322581 4:47.0 5:309.0 6:5.0 7:14.8722222222

(X) 表示行号,第一个值是类标签,从 1-7 的值是特征索引,每个后面跟着它的值。现在我想知道函数的确切拆分策略。为了更清楚地说明,以下示例显示了如何将数据拆分为每次迭代的测试集和训练集的两种不同方式:

示例 1:

迭代 1:(1) - (2) 进行测试

迭代 2:(3) - (4) 进行测试

迭代 3:(5) - (6) 进行测试

...

示例 2

迭代 1:(1) 和 (3) 进行测试

迭代 2:(2) 和 (4) 进行测试

迭代3:(5)和(7)进行测试

...

有人知道函数使用的确切拆分策略吗?或者任何人都可以声明一个函数来查看这些拆分,而不仅仅是查看结果?

提前感谢您的时间和精力。

【问题讨论】:

    标签: python scikit-learn cross-validation


    【解决方案1】:

    如果cv_folds是一个cv-object,那么看看list(self.cv_folds),你会发现一个元组列表[(train1, test1), (train2, test2), ...]

    如果self.cv_folds 只是一个数字,则考虑明确设置交叉验证迭代器,例如如下:

    from sklearn.cross_validation import KFold, StratifiedKFold
    ## Choose one of the two next lines
    cv = KFold(self.cv_folds)   # for regression
    cv = StratifiedKFold(y, self.cv_folds)   # for classification
    
    cv_scores = cross_validation.cross_val_score(rfc, X, y, cv=cv)
    

    现在,使用list(cv),您可以恢复训练/测试拆分的所有索引。

    请注意,默认的交叉验证迭代器取决于您的估算器(分类器或回归器)的性质,因此如果您需要这种详细程度,最好明确说明。

    【讨论】:

    • 第一个cv_folds 是一个数字,我将此信息添加到我的帖子中。非常感谢您,这看起来与我正在寻找的解决方案一模一样。我会试试这个,如果我还有任何问题,我可能会添加另一条评论。谢谢。
    猜你喜欢
    • 2014-07-27
    • 1970-01-01
    • 2013-12-13
    • 2012-12-31
    • 2018-08-16
    • 2019-08-30
    • 2015-06-11
    • 2017-03-15
    • 2020-07-10
    相关资源
    最近更新 更多