【问题标题】:Nested cross-validation with GroupKFold with sklearnGroupKFold 和 sklearn 的嵌套交叉验证
【发布时间】:2021-02-10 20:33:47
【问题描述】:

在我的数据中,多个条目对应于一个主题,我不会在训练集和测试集之间混合这些条目。出于这个原因,我查看了 GroupKFold 折叠迭代器,根据 sklearn 文档,它是一个 “具有非重叠组的 K 折叠迭代器变体”。 因此,我想使用GroupKFold 实现嵌套交叉验证来拆分测试和训练集。

我从this question 中给出的模板开始。但是,我在网格实例上调用fit 方法时遇到错误,说groups 的形状与Xy 的形状不同。为了解决这个问题,我也使用火车索引对groups 进行了切片。

这个实现是否正确?我最关心的是不要在训练集和测试集之间混合来自同一组的数据。

inner_cv = GroupKFold(n_splits=inner_fold)
outer_cv = GroupKFold(n_splits=out_fold)


for train_index, test_index in outer_cv.split(x, y, groups=groups):
    x_train, x_test = x[train_index], x[test_index]
    y_train, y_test = y[train_index], y[test_index]

    grid = RandomizedSearchCV(estimator=model,
                                param_distributions=parameters_grid,
                                cv=inner_cv,
                                scoring=get_scoring(),
                                refit='roc_auc_scorer',
                                return_train_score=True,
                                verbose=1,
                                n_jobs=jobs)
    grid.fit(x_train, y_train, groups=groups[train_index])
    prediction = grid.predict(x_test)

【问题讨论】:

  • 您链接的问题似乎不正确。这是一个javascript问题。
  • 谢谢@bernie,我已经修复了链接

标签: python machine-learning scikit-learn cross-validation k-fold


【解决方案1】:

您可以确认代码是否按照您的预期执行的一种方法(即不混合组之间的数据)是您不能将 GroupKFold 对象而是 GroupKFold.split 的输出(索引)传递给 RandomizedSearchCV。例如

grid = RandomizedSearchCV(estimator=model,
                            param_distributions=parameters_grid,
                            cv=inner_cv.split(
                              x_train, y_train, groups=groups[train_index]),
                            scoring=get_scoring(),
                            refit='roc_auc_scorer',
                            return_train_score=True,
                            verbose=1,
                            n_jobs=jobs)
grid.fit(x_train, y_train)

我相信这会导致相同的拟合结果,在这里您已经明确给出了交叉验证每一折的训练/验证指数。

据我所知,这两种方法是等效的,但我认为您的示例编写方式更优雅,因为您没有提供两次 x_trainy_train

使用train_index 切片groups 似乎是正确的,因为您只是将切片的xy 变量传递给fit 方法。我必须提醒自己,内部交叉验证将在外部交叉验证操作的训练子集上进行交叉验证。

【讨论】:

    猜你喜欢
    • 2020-07-14
    • 2012-12-31
    • 2017-03-16
    • 1970-01-01
    • 2018-02-03
    • 2017-12-22
    • 2018-08-16
    • 2018-04-26
    • 2021-03-25
    相关资源
    最近更新 更多