【问题标题】:How to determine best parameters and best score for each scoring metric in GridSearchCV如何确定 GridSearchCV 中每个评分指标的最佳参数和最佳分数
【发布时间】:2020-11-09 21:39:47
【问题描述】:

我正在尝试评估多个评分指标以确定模型性能的最佳参数。即,说:

为了最大化 F1,我应该使用这些参数。为了最大限度地提高精度,我 应该使用这些参数。

我正在处理来自this sklearn page 的以下示例

import numpy as np

from sklearn.datasets import make_hastie_10_2
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import make_scorer
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

X, y = make_hastie_10_2(n_samples=5000, random_state=42)


scoring = {'PRECISION': 'precision', 'F1': 'f1'}

gs = GridSearchCV(DecisionTreeClassifier(random_state=42),
                  param_grid={'min_samples_split': range(2, 403, 10)},
                  scoring=scoring, refit='F1', return_train_score=True)
gs.fit(X, y)
best_params = gs.best_params_
best_estimator = gs.best_estimator_

print(best_params)
print(best_estimator)

产量:

{'min_samples_split': 62}
DecisionTreeClassifier(min_samples_split=62, random_state=42)

但是,我要寻找的是为每个指标找到这些结果,所以在这种情况下,对于 F1precision

如何在GridSearchCV 中为每种类型的评分指标获取最佳参数?

注意 - 我认为这与我对 refit='F1' 的使用有关,但不确定如何在其中使用多个指标?

【问题讨论】:

    标签: python machine-learning scikit-learn cross-validation grid-search


    【解决方案1】:

    为此,您必须深入了解整个网格搜索 CV 程序的详细结果;幸运的是,这些详细结果在GridSearchCV 对象(docs)的cv_results_ 属性中返回。

    我已按原样重新运行您的代码,但我不会在此处重新键入它;可以说,尽管明确设置了随机数生成器的种子,但我得到了不同的最终结果(我猜是由于版本不同):

    {'min_samples_split': 322}
    DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                           max_depth=None, max_features=None, max_leaf_nodes=None,
                           min_impurity_decrease=0.0, min_impurity_split=None,
                           min_samples_leaf=1, min_samples_split=322,
                           min_weight_fraction_leaf=0.0, presort='deprecated',
                           random_state=42, splitter='best')
    

    但这对于手头的问题并不重要。

    使用返回的 cv_results_ 字典的最简单方法是将其转换为 pandas 数据框:

    import pandas as pd
    cv_results = pd.DataFrame.from_dict(gs.cv_results_)
    

    不过,由于它包含太多信息(列),我将在此处进一步简化以演示问题(您可以自己更全面地探索它):

    df = cv_results[['params', 'mean_test_PRECISION', 'rank_test_PRECISION', 'mean_test_F1', 'rank_test_F1']]
    
    pd.set_option("display.max_rows", None, "display.max_columns", None)
    pd.set_option('expand_frame_repr', False)
    print(df)
    

    结果:

                            params  mean_test_PRECISION  rank_test_PRECISION  mean_test_F1  rank_test_F1
    0     {'min_samples_split': 2}             0.771782                    1      0.763041            41
    1    {'min_samples_split': 12}             0.768040                    2      0.767331            38
    2    {'min_samples_split': 22}             0.767196                    3      0.776677            29
    3    {'min_samples_split': 32}             0.760282                    4      0.773634            32
    4    {'min_samples_split': 42}             0.754572                    8      0.777967            26
    5    {'min_samples_split': 52}             0.754034                    9      0.777550            27
    6    {'min_samples_split': 62}             0.758131                    5      0.773348            33
    7    {'min_samples_split': 72}             0.756021                    6      0.774301            30
    8    {'min_samples_split': 82}             0.755612                    7      0.768065            37
    9    {'min_samples_split': 92}             0.750527                   10      0.771023            34
    10  {'min_samples_split': 102}             0.741016                   11      0.769896            35
    11  {'min_samples_split': 112}             0.740965                   12      0.765353            39
    12  {'min_samples_split': 122}             0.731790                   13      0.763620            40
    13  {'min_samples_split': 132}             0.723085                   14      0.768605            36
    14  {'min_samples_split': 142}             0.713345                   15      0.774117            31
    15  {'min_samples_split': 152}             0.712958                   16      0.776721            28
    16  {'min_samples_split': 162}             0.709804                   17      0.778287            24
    17  {'min_samples_split': 172}             0.707080                   18      0.778528            22
    18  {'min_samples_split': 182}             0.702621                   19      0.778516            23
    19  {'min_samples_split': 192}             0.697630                   20      0.778103            25
    20  {'min_samples_split': 202}             0.693011                   21      0.781047            10
    21  {'min_samples_split': 212}             0.693011                   21      0.781047            10
    22  {'min_samples_split': 222}             0.693011                   21      0.781047            10
    23  {'min_samples_split': 232}             0.692810                   24      0.779705            13
    24  {'min_samples_split': 242}             0.692810                   24      0.779705            13
    25  {'min_samples_split': 252}             0.692810                   24      0.779705            13
    26  {'min_samples_split': 262}             0.692810                   24      0.779705            13
    27  {'min_samples_split': 272}             0.692810                   24      0.779705            13
    28  {'min_samples_split': 282}             0.692810                   24      0.779705            13
    29  {'min_samples_split': 292}             0.692810                   24      0.779705            13
    30  {'min_samples_split': 302}             0.692810                   24      0.779705            13
    31  {'min_samples_split': 312}             0.692810                   24      0.779705            13
    32  {'min_samples_split': 322}             0.688417                   33      0.782772             1
    33  {'min_samples_split': 332}             0.688417                   33      0.782772             1
    34  {'min_samples_split': 342}             0.688417                   33      0.782772             1
    35  {'min_samples_split': 352}             0.688417                   33      0.782772             1
    36  {'min_samples_split': 362}             0.688417                   33      0.782772             1
    37  {'min_samples_split': 372}             0.688417                   33      0.782772             1
    38  {'min_samples_split': 382}             0.688417                   33      0.782772             1
    39  {'min_samples_split': 392}             0.688417                   33      0.782772             1
    40  {'min_samples_split': 402}             0.688417                   33      0.782772             1
    

    列的名称应该是不言自明的;它们包括尝试的参数、使用的每个指标的分数以及相应的排名(1 表示最佳)。例如,您可以立即看到,尽管 'min_samples_split': 322 确实给出了最好的 F1 分数,但它不是唯一这样做的参数设置,还有更多设置也在结果中给出最好的 F1 分数和 1rank_test_F1

    从这一点来说,获得你想要的信息是微不足道的;例如,以下是您的两个指标中每一个指标的最佳模型:

    print(df.loc[df['rank_test_PRECISION']==1]) # best precision
    # result:
                         params  mean_test_PRECISION  rank_test_PRECISION  mean_test_F1  rank_test_F1
    0  {'min_samples_split': 2}             0.771782                    1      0.763041            41
    
    print(df.loc[df['rank_test_F1']==1]) # best F1
    # result:
                            params  mean_test_PRECISION  rank_test_PRECISION  mean_test_F1  rank_test_F1
    32  {'min_samples_split': 322}             0.688417                   33      0.782772             1
    33  {'min_samples_split': 332}             0.688417                   33      0.782772             1
    34  {'min_samples_split': 342}             0.688417                   33      0.782772             1
    35  {'min_samples_split': 352}             0.688417                   33      0.782772             1
    36  {'min_samples_split': 362}             0.688417                   33      0.782772             1
    37  {'min_samples_split': 372}             0.688417                   33      0.782772             1
    38  {'min_samples_split': 382}             0.688417                   33      0.782772             1
    39  {'min_samples_split': 392}             0.688417                   33      0.782772             1
    40  {'min_samples_split': 402}             0.688417                   33      0.782772             1
    

    【讨论】:

    • 只是为了便于理解...这表明min_samples_split: 2 是最大化精度的最佳超参数,即使使用refit=F1?
    • @wundermahn 完全正确;通过df,您可以轻松确认0.771782 的相应精度值确实是最大值。您在refit 中指定的内容决定了该过程将返回为best_paramsbest_estimator 的内容(这就是为什么在这里您获得最大化F1 的参数,而不是精度),因为很明显您不能优化超过一个指标同时
    猜你喜欢
    • 2020-03-07
    • 2019-07-24
    • 2021-10-27
    • 1970-01-01
    • 2019-04-07
    • 2023-02-12
    • 2020-10-10
    • 1970-01-01
    • 2017-05-19
    相关资源
    最近更新 更多