【问题标题】:How to work out how many models GridSearchCV will train?如何计算出 GridSearchCV 将训练多少个模型?
【发布时间】:2023-04-03 17:43:01
【问题描述】:

您如何计算出 SKLearn 的 GridSearchCV 将训练多少个模型?就我而言,我使用以下参数:

learning_rate_range = [0.01, 0.05, 0.1]
max_depth_range = [3, 4, 5, 6, 7]
min_child_weight_range = [6, 7, 8]
subsample_range = [0.6, 0.7, 0.8, 0.9]
colsample_range = [0.7, 0.8, 0.9]

例如,如果您使用 3 折交叉验证,总共将训练多少个模型,解决此问题的一般方法是什么?

【问题讨论】:

  • 当使用带有详细参数的 GridSearchCV 时,您看到的第一行将是您询问的特定数字。

标签: python scikit-learn grid-search


【解决方案1】:

根据文档:“GridSearchCV 详尽地考虑了所有参数组合,而 RandomizedSearchCV 可以从具有指定分布的参数空间中采样给定数量的候选者。”。

http://scikit-learn.org/stable/modules/grid_search.html#grid-search

还有一个实际使用的 GridSearchCV 示例:

http://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_digits.html#sphx-glr-auto-examples-model-selection-plot-grid-search-digits-py

如果您将上述所有参数传递到一个字典中,您将获得 3x5x3x4x3 个网格点,每个点将被交叉验证 3 次。

【讨论】:

    【解决方案2】:

    正如@KRKirov 所说,参数总数只是每个参数的各个级别的乘积。 SciKit learn 提供了一个简单的方法来知道参数的总数,如下所示:

    from sklearn.model_selection import ParameterGrid
    
    parameters = {
    learning_rate_range: [0.01, 0.05, 0.1]
    max_depth_range: [3, 4, 5, 6, 7]
    min_child_weight_range: [6, 7, 8]
    subsample_range: [0.6, 0.7, 0.8, 0.9]
    colsample_range: [0.7, 0.8, 0.9]
    }
    
    grid = ParameterGrid(parameters)
    # python 3.6+ for the f format
    print (f"The total number of parameters-combinations is: {len(grid)}")
    

    请记住,每个参数组合都会执行 5 次以进行交叉验证。因此,总执行次数为5 * len(grid)

    【讨论】:

      猜你喜欢
      • 2018-09-28
      • 2020-06-17
      • 2014-10-13
      • 2019-09-03
      • 2016-08-28
      • 2019-04-01
      • 1970-01-01
      • 2021-11-27
      • 2014-06-16
      相关资源
      最近更新 更多