【Question title】: How to perform SMOTE with cross-validation in sklearn in Python
【Posted】: 2019-08-30 15:27:09
【Question】:

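I have a highly imbalanced dataset. I want to use SMOTE to balance it and cross-validation to measure accuracy. However, most existing tutorials apply SMOTE with only a single training/testing split.

So I would like to know the correct procedure for combining SMOTE with cross-validation.

My current code is below. As mentioned above, it uses only a single split.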

from imblearn.over_sampling import SMOTE
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
sm = SMOTE(random_state=2)
# fit_resample replaces the older fit_sample, which recent imbalanced-learn releases removed
X_train_res, y_train_res = sm.fit_resample(X_train, y_train.ravel())
clf_rf = RandomForestClassifier(n_estimators=25, random_state=12)
clf_rf.fit(X_train_res, y_train_res)

I'm happy to provide more details if needed.

【Discussion】:

    Tags: python machine-learning scikit-learn classification cross-validation


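    【Answer 1】:

    I think you can also solve this with the Pipeline class from the imbalanced-learn (imblearn) library.

    I saw this solution on the Machine Learning Mastery blog: https://machinelearningmastery.com/smote-oversampling-for-imbalanced-classification/

    The idea is to use imblearn's Pipeline inside the cross-validation, so that SMOTE is applied only to the training folds. Let me know whether this works. The example below uses a decision tree, but the logic is the same for any classifier.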

    #decision tree evaluated on imbalanced dataset with SMOTE oversampling
    from numpy import mean
    from sklearn.datasets import make_classification
    from sklearn.model_selection import cross_val_score
    from sklearn.model_selection import RepeatedStratifiedKFold
    from sklearn.tree import DecisionTreeClassifier
    from imblearn.pipeline import Pipeline
    from imblearn.over_sampling import SMOTE
    # define dataset
    X, y = make_classification(n_samples=10000, n_features=2, n_redundant=0,
        n_clusters_per_class=1, weights=[0.99], flip_y=0, random_state=1)
    # define pipeline
    steps = [('over', SMOTE()), ('model', DecisionTreeClassifier())]
    pipeline = Pipeline(steps=steps)
    # evaluate pipeline
    cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
    scores = cross_val_score(pipeline, X, y, scoring='roc_auc', cv=cv, n_jobs=-1)
    score = mean(scores)
    print('Mean ROC AUC: %.3f' % score)
    


      【Answer 2】:

      You need to perform SMOTE inside each fold, so avoid train_test_split and use KFold instead:

      from imblearn.over_sampling import SMOTE
      from sklearn.ensemble import RandomForestClassifier
      from sklearn.metrics import f1_score
      from sklearn.model_selection import KFold
      
      kf = KFold(n_splits=5)
      
      # X and y are assumed to be NumPy arrays; for a pandas DataFrame use X.iloc[train_index]
      for fold, (train_index, test_index) in enumerate(kf.split(X), 1):
          X_train = X[train_index]
          y_train = y[train_index]  # Based on your code, you might need a ravel call here, but I would look into how you're generating your y
          X_test = X[test_index]
          y_test = y[test_index]  # See comment on ravel and y_train
          # Oversample the training fold only; the test fold keeps its original class balance
          sm = SMOTE()
          X_train_oversampled, y_train_oversampled = sm.fit_resample(X_train, y_train)
          model = RandomForestClassifier(n_estimators=25, random_state=12)  # any classifier works here
          model.fit(X_train_oversampled, y_train_oversampled)
          y_pred = model.predict(X_test)
          print(f'For fold {fold}:')
          print(f'Accuracy: {model.score(X_test, y_test)}')
          print(f'f-score: {f1_score(y_test, y_pred)}')
      

      You could also, for example, append the scores to a list defined outside the loop.

      【Comments】:

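      • Note: you may want to use StratifiedKFold instead, as in the other answer, since you have an imbalanced-class problem.
      • Thank you very much. I also have a y. In that case, how should I change `in enumerate(kf.split(X), 1):`?
      • @Emi You don't need to change it. All kf.split does is take the size of X (how many rows it has) to decide how to generate the indices for each fold. Since your y should be the same size as X, you don't need to pass it. That said, you can call kf.split(X, y) and it will have the same effect.
      • @gmds A small question: why do you fit the model with model.fit(X_train, y_train) rather than on the oversampled data X_train_oversampled and y_train_oversampled?
      • @Hiyam That was actually my mistake, thanks! Will edit.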
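The two points raised in the comments can be checked directly. This sketch uses hypothetical toy data (100 rows, 10 minority samples sorted to the front) to show that `KFold.split` yields identical folds with or without `y`, while `StratifiedKFold` uses `y` to spread the minority class evenly across folds. Without shuffling, plain `KFold` puts all minority rows of this sorted toy data into the first test fold.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical toy data: 100 rows, first 10 belong to the minority class
X = np.arange(200).reshape(100, 2)
y = np.array([1] * 10 + [0] * 90)

kf = KFold(n_splits=5)
# KFold only looks at the number of rows, so passing y changes nothing
same = all(np.array_equal(a_test, b_test)
           for (_, a_test), (_, b_test) in zip(kf.split(X), kf.split(X, y)))
print('KFold ignores y:', same)
print('KFold minority per test fold:', [int(y[t].sum()) for _, t in kf.split(X)])

# StratifiedKFold requires y and keeps the class ratio in every fold
skf = StratifiedKFold(n_splits=5)
print('StratifiedKFold minority per test fold:',
      [int(y[t].sum()) for _, t in skf.split(X, y)])
```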
      【Answer 3】:
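      from imblearn.over_sampling import SMOTE
      from sklearn.ensemble import RandomForestClassifier
      from sklearn.model_selection import StratifiedKFold
      
      cv = StratifiedKFold(n_splits=5)
      for train_idx, test_idx in cv.split(X, y):
          X_train, y_train = X[train_idx], y[train_idx]
          X_test, y_test = X[test_idx], y[test_idx]
          # Resample the training fold only
          X_train, y_train = SMOTE().fit_resample(X_train, y_train)
          # Fit and evaluate any classifier here, e.g. the question's random forest
          model = RandomForestClassifier(n_estimators=25, random_state=12).fit(X_train, y_train)
          print(model.score(X_test, y_test))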
      
