【问题标题】:sklearn grid search f1_score does not match f1_score functionsklearn 网格搜索 f1_score 与 f1_score 函数不匹配
【发布时间】:2017-04-28 01:51:11
【问题描述】:

我一直在尝试使用 sklearn 网格搜索和管道功能,并注意到返回的 f1_score 与我使用硬编码参数生成的 f1_score 不匹配。寻求帮助以了解为什么会这样。

数据背景:两列.csv文件

客户评论(字符串),类别标签(字符串)

使用开箱即用的 sklearn 词袋方法,无需对文本进行预处理,仅使用 countVectorizer。

硬编码模型...

get .csv data into dataFrame
data_file = 'comment_data_basic.csv'
data = pd.read_csv(data_file,header=0,quoting=3)

#remove data without 'web issue' or 'product related' tag
data = data.drop(data[(data.tag != 'WEB ISSUES') & (data.tag != 'PRODUCT RELATED')].index)

#split dataFrame into two series
comment_data = data['comment']
tag_data = data['tag']

#split data into test and train samples
comment_train, comment_test, tag_train, tag_test = train_test_split(
    comment_data, tag_data, test_size=0.33)

#build count vectorizer
vectorizer = CountVectorizer(min_df=.002,analyzer='word',stop_words='english',strip_accents='unicode')
vectorizer.fit(comment_data)

#vectorize features and convert to array
comment_train_features = vectorizer.transform(comment_train).toarray()
comment_test_features = vectorizer.transform(comment_test).toarray()

#train LinearSVM Model
lin_svm = LinearSVC()
lin_svm = lin_svm.fit(comment_train_features,tag_train)

#make predictions
lin_svm_predicted_tags = lin_svm.predict(comment_test_features)

#score models
lin_svm_score = round(f1_score(tag_test,lin_svm_predicted_tags,average='macro'),3)
lin_svm_accur = round(accuracy_score(tag_test,lin_svm_predicted_tags),3)
lin_svm_prec = round(precision_score(tag_test,lin_svm_predicted_tags,average='macro'),3)
lin_svm_recall = round(recall_score(tag_test,lin_svm_predicted_tags,average='macro'),3)

#write out scores
print('Model    f1Score   Accuracy   Precision   Recall')
print('------   -------   --------   ---------   ------')
print('LinSVM   {f1:.3f}     {ac:.3f}      {pr:.3f}       {re:.3f}  '.format(f1=lin_svm_score,ac=lin_svm_accur,pr=lin_svm_prec,re=lin_svm_recall))

f1_score 输出一般在 0.86 左右(取决于随机种子值)

现在,如果我基本上用网格搜索和管道重建相同的输出......

#get .csv data into dataFrame
data_file = 'comment_data_basic.csv'
data = pd.read_csv(data_file,header=0,quoting=3)

#remove data without 'web issue' or 'product related' tag
data = data.drop(data[(data.tag != 'WEB ISSUES') & (data.tag != 'PRODUCT RELATED')].index)

#build processing pipeline
pipeline = Pipeline([
    ('vect', CountVectorizer()),
    ('clf', LinearSVC()),])

#define parameters to be used in gridsearch
parameters = {
    #'vect__min_df': (.001,.002,.003,.004,.005),
    'vect__analyzer': ('word',),
    'vect__stop_words': ('english', None),
    'vect__strip_accents': ('unicode',),
    #'clf__C': (1,10,100,1000),
}

if __name__ == '__main__':

    grid_search = GridSearchCV(pipeline,parameters,scoring='f1_macro',n_jobs=1)

    grid_search.fit(data['comment'],data['tag'])

    print("Best score: %0.3f" % grid_search.best_score_)
    print("Best parameters set:")
    best_params = grid_search.best_estimator_.get_params()
    for param_name in sorted(parameters.keys()):
        print("\t%s: %r" % (param_name, best_params[param_name]))

返回的f1_score接近0.73,所有模型参数相同。我的理解是网格搜索在内部应用了交叉验证方法,但我的猜测是,与在原始代码中使用 test_train_split 相比,它使用的任何方法都不同。然而,从 0.83 -> 0.73 的下降对我来说感觉很大,我希望对自己的结果充满信心。

任何见解将不胜感激。

【问题讨论】:

标签: python scikit-learn cross-validation


【解决方案1】:

在您提供的代码中,您没有设置 LinearSVC 模型的 random_state 参数,因此即使使用相同的超参数,您也不太可能从您的 GridSearchCV 复制最佳估计器的精确副本。然而,这比实际发生的事情更微不足道。

在您的案例中,GridSearch 正在使用 3 折数据进行交叉验证。您看到的 best_score 是在您的测试数据上得分时在所有折叠中平均表现最佳的模型得分,它可能不是在您的训练/测试拆分中得分最高的估计器。考虑到您提供 GridSearch 的拆分,不同的估算器可能会得分更高,但如果您要生成一些不同的拆分并在每个测试集上对估算器进行评分,平均而言,best_estimator 将出现最佳。这个想法是,通过交叉验证,您将选择一个对不一定在单个训练/测试拆分中表示的数据变化更具弹性的估计器。因此,您进行的拆分越多,您的模型在新的未见数据上的表现就越好。在这种情况下,更好可能并不意味着它每次都会产生更准确的结果,但考虑到现有数据中存在的变化,模型将在包含这些变化方面做得更好,并且从长远来看平均会产生更准确的结果,因为只要新的看不见的数据在训练数据中看到的范围内。

如果您想了解有关估算器在拆分中的表现的更多信息,请查看grid_search.cv_results_,以更好地了解在整个过程中逐步发生的情况。

【讨论】:

    猜你喜欢
    • 2023-04-01
    • 2020-05-22
    • 1970-01-01
    • 2016-01-19
    • 1970-01-01
    • 2022-12-19
    • 2016-02-23
    • 2017-08-29
    • 1970-01-01
    相关资源
    最近更新 更多