【发布时间】:2023-03-17 15:28:02
【问题描述】:
我正在研究 PySpark 二进制分类管道,我想在其中使用过采样阶段执行交叉验证(我的数据集不平衡)。问题是过采样阶段也在测试数据集上执行。
管道:
pipeline=Pipeline(stages=[cast_and_fill_na, smote, vec_assembler, rf])
smote 是我在转换测试数据集时要跳过的阶段。
我查看了 spark 文档和源代码,无法跳过 PipelineModel 中的某个阶段。我的解决方案是覆盖原始类的_transform 方法,以跳过过采样阶段。
在我的源代码中安装管道时,这可以正常工作。我用这个:
pipeline_model.__class__ = CustomPipelineModel
CustomPipelineModel 是一个继承自pyspark.ml.PipelineModel 并覆盖_transform 方法的类。
但是由于 CrossValidator 使用的是 PipelineModel 类的原始实现,所以我不能使用我的自定义方法。
evaluator = BinaryClassificationEvaluator(labelCol=target)
crossval = CrossValidator(estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=evaluator,
numFolds=10,
parallelism=1)
cvModel = crossval.fit(train_set)
使用 Cross Validator 时跳过过采样阶段的最佳方法是什么?
我开始研究pyspark.ml.tuning.CrossValidator 的_fit 方法的源代码,考虑也覆盖它...第二种解决方案是对训练数据集执行过采样,但这会在模型中引入偏差交叉验证过程。
【问题讨论】:
标签: python pyspark cross-validation oversampling smote