【问题标题】:Oversampling with Cross Validation in PySpark Pipeline在 PySpark 管道中使用交叉验证进行过采样
【发布时间】:2023-03-17 15:28:02
【问题描述】:

我正在研究 PySpark 二进制分类管道,我想在其中使用过采样阶段执行交叉验证(我的数据集不平衡)。问题是过采样阶段也在测试数据集上执行。

管道:

pipeline=Pipeline(stages=[cast_and_fill_na, smote, vec_assembler, rf])

smote 是我在转换测试数据集时要跳过的阶段。

我查看了 spark 文档和源代码,无法跳过 PipelineModel 中的某个阶段。我的解决方案是覆盖原始类的_transform 方法,以跳过过采样阶段。 在我的源代码中安装管道时,这可以正常工作。我用这个:

pipeline_model.__class__ = CustomPipelineModel

CustomPipelineModel 是一个继承自pyspark.ml.PipelineModel 并覆盖_transform 方法的类。

但是由于 CrossValidator 使用的是 PipelineModel 类的原始实现,所以我不能使用我的自定义方法。

evaluator = BinaryClassificationEvaluator(labelCol=target)    
crossval = CrossValidator(estimator=pipeline,
                                      estimatorParamMaps=paramGrid,
                                      evaluator=evaluator,
                                      numFolds=10,
                                      parallelism=1)
cvModel = crossval.fit(train_set)

使用 Cross Validator 时跳过过采样阶段的最佳方法是什么?

我开始研究pyspark.ml.tuning.CrossValidator_fit 方法的源代码,考虑也覆盖它...第二种解决方案是对训练数据集执行过采样,但这会在模型中引入偏差交叉验证过程。

【问题讨论】:

    标签: python pyspark cross-validation oversampling smote


    【解决方案1】:

    我想出了一个解决这个问题的方法。 在我的 SMOTEOversmapler 类(smote 阶段是它的一个实例)中,我添加了一个名为 skip_transform 的属性,在实例化 SMOTEOversmapler 对象时设置为 None。在_transform 方法中,我将此属性设置为True。将跳过对_transform(处于测试阶段)的下一次调用。这是一个代码sn-p。

    def __init__(self, ...):
        self.skip_transfrom = None
    def _transform(self, df):
        if self.skip_transform:
             retrun df
        else:
             #Execute oversampling
             self.skip_transform = True
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-09-24
      • 2015-10-29
      • 2020-03-07
      • 2018-09-05
      • 2016-11-15
      • 2020-10-28
      • 2019-10-02
      • 2019-10-27
      相关资源
      最近更新 更多