【问题标题】:Should I perform Cross Validation first and then do grid search?我应该先执行交叉验证,然后再进行网格搜索吗?
【发布时间】:2020-05-28 22:46:16
【问题描述】:

我是机器学习领域的新手。我的问题如下:我已经建立了一个模型,并且我正在尝试优化这样的模型。通过做一些研究,我发现交叉验证可以用来帮助我避免过度拟合的模型。此外,Gridsearchcv 可用于帮助我优化此类模型的参数并最终确定可能的最佳参数。

现在我的问题是我应该先进行交叉验证,然后使用网格搜索来识别最佳参数,还是使用 GridsearchCV 就足够了,因为它本身执行交叉验证?

【问题讨论】:

    标签: machine-learning scikit-learn cross-validation grid-search


    【解决方案1】:

    正如@Noki 所建议的,您可以在 Grid Search CV 中使用 cv 参数。

    GridSearchCV(estimator, param_grid, scoring=None, n_jobs=None, iid='deprecated', 
    refit=True, cv=None, verbose=0, 
    pre_dispatch='2*n_jobs',error_score=nan,return_train_score=False)
    

    文档还明确指出,如果是分类问题,它将自动确保它是分层的。

    对于整数/无输入,如果估计器是分类器并且 y 是二进制或 多类,使用 StratifiedKFold。在所有其他情况下,使用 KFold。

    但是,我想补充一点: 您可以根据您的 Y_target 变量的值计数使您的 K-folds 动态化。 您不能将 K-fold 中频率的最低计数设为 1,它会在训练时引发错误。我碰巧遇到了这种情况。使用下面的代码 sn-p 来帮助你。

    例如

    import pandas as pd
    Y_target=pd.Series([0,1,1,1,1,0,0,0,6,6,6,6,6,6,6,6,6])
    
    if Y_target.value_counts().iloc[-1]<2:
        raise Exception ("No value can have frequency count as 1 in Y-target")
    else:
        Kfold=Y_target.value_counts().iloc[-1]
    

    然后您可以在网格搜索中将 Kfold 分配给您的 cv 参数

    【讨论】:

    • 我的问题是一个多类分类问题,所以最好使用分层的 k 折叠而不是正常的 k 折叠?如果是这样,为什么会这样?因为我找不到任何相关的东西。
    • 是的,最好使用 Stratified K-fold,因为它可以确保您的训练集和测试集之间的数据比率保持不变。
    • 如果您在 GridSearch CV 中这样做,它会自动得到处理。使用上面的代码 sn-p 来确定折叠。您取相同的值或任何小于变量 K-fold 值的值。没有问题。
    【解决方案2】:

    Cross validation with test data set

    我的建议,如果你的数据集足够大:

    1. 将您的数据集拆分为训练和测试子集。
    2. 对您的训练数据集执行GridSearchCV
    3. 在您的测试子集上评估您的最佳模型(来自 GridSearchCV)。

    【讨论】:

    • 使用交叉验证将我的数据拆分为训练和测试是合适的,而不是使用拆分,例如 80% 训练和 20% 测试?
    • 如果您有足够的数据,请将 GridSearchCV 之前的数据集拆分为训练和测试。例如,请参见此处:stats.stackexchange.com/questions/148688/…
    • 感谢您的信息,然后我将使用交叉验证拆分我的数据,我将在此类数据上训练此类模型并最终使用 GridSearchCV(通过指定相同的折叠数)来识别不同的参数模型可以采用的。您认为这是一个好方法吗?
    【解决方案3】:

    现在我的问题是我应该先进行交叉验证,然后使用网格搜索来识别最佳参数,还是使用 GridsearchCV 就足够了,因为它本身执行交叉验证?

    第二个。 GridSearchCV 使用交叉验证拆分策略来选择最佳参数。如果您阅读scikit-learn documentation,有一个名为“cv”的参数,它默认定义了 5 折交叉验证。如果你需要使用其他的交叉验证策略,你可以给它一个int,交叉验证生成器或者一个iterable

    【讨论】:

    • 那么考虑到网格搜索也可以执行 CV,就不需要事先拆分数据吧?
    • 是的,如果您想使用这 20 个作为验证数据集并且更安全地使您的模型能够很好地泛化,您最多可以进行 80-20 拆分(或 10 或其他)。跨度>
    猜你喜欢
    • 2021-09-10
    • 2018-05-10
    • 2017-06-19
    • 2020-02-19
    • 2019-09-29
    • 2021-12-11
    • 1970-01-01
    • 2019-03-30
    • 2015-09-18
    相关资源
    最近更新 更多