【问题标题】:Function for cross validation and oversampling (SMOTE)交叉验证和过采样功能 (SMOTE)
【发布时间】:2019-10-02 14:30:39
【问题描述】:

我写了下面的代码。 X 是形状为(1000,5) 的数据框,y 是形状为(1000,1) 的数据框。 y是要预测的目标数据,是不平衡的。我想应用交叉验证和 SMOTE。

def Learning(n, est, X, y):
    s_k_fold = StratifiedKFold(n_splits = n)
    acc_scores = []
    rec_scores = []
    f1_scores = []

    for train_index, test_index in s_k_fold.split(X, y): 
        X_train = X[train_index]
        y_train = y[train_index]    

        sm = SMOTE(random_state=42)
        X_resampled, y_resampled = sm.fit_resample(X_train, y_train)

        X_test = X[test_index]
        y_test = y[test_index]

        est.fit(X_resampled, y_resampled)
        y_pred = est.predict(X_test)
        acc_scores.append(accuracy_score(y_test, y_pred))
        rec_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred)) 

    print('Accuracy:',np.mean(acc_scores))
    print('Recall:',np.mean(rec_scores))
    print('F1:',np.mean(f1_scores)) 

Learning(3, SGDClassifier(), X_train_s_pca, y_train)

当我运行代码时,我收到以下错误:

[Int64Index([ 4231, 4235, 4246, 4250, 4255, 4295, 4317, 4344, 4381,\n 4387,\n ...\n 13122, 13123, 13124, 13125, 13126, 13127, 13128, 13129, 13130,\n
13131],\n dtype='int64', length=8754)] 在[列]"

感谢帮助使其运行。

【问题讨论】:

    标签: python cross-validation oversampling


    【解决方案1】:

    如果您仔细观察错误堆栈跟踪(这很重要,但您没有包含),您应该会看到错误来自这些行(并且将来自其他类似的行):

    X_train = X[train_index]
    

    这种选择行的方式只适用于 Numpy 数组。由于您使用的是 Pandas DataFrame,因此您应该使用loc

    X_train = X.loc[train_index]
    

    或者,您可以使用 values 将 DataFrame 转换为 Numpy 数组(以尽量减少代码更改):

    Learning(3, SGDClassifier(), X_train_s_pca.values, y_train.values)
    

    【讨论】:

      猜你喜欢
      • 2015-10-29
      • 2019-07-27
      • 2019-10-27
      • 2018-06-30
      • 2020-08-29
      • 2018-09-05
      • 1970-01-01
      • 2023-03-17
      • 2019-05-24
      相关资源
      最近更新 更多