【问题标题】:How to simplify my data preprocessing with scikit learn pipelines如何使用 scikit learn 管道简化我的数据预处理
【发布时间】:2019-03-29 09:59:55
【问题描述】:

我有 2 个 dfs。 df1 是猫的例子,df2 是狗的例子。

我必须通过调用不同的函数对这些 dfs 进行一些预处理。我想使用 scikit 学习管道。

其中一个函数是一个特殊的编码器函数,它将查看 df 中的列并返回一个特殊值。我在 我看到在 scikit learn 中使用这样的类中重写了该函数:

class Encoder(BaseEstimator, TransformerMixin):

    def __init__(self):
        self.values = []
        super().__init__()

    def fit(self, X, y=None):
        return self

    def encode(self,row):
        result = []
        for base in row:
            result.append(bases[base])

        self.values.append(result)

    def transform(self, X):
        assert isinstance(X, pd.DataFrame)
        X["seq_new"].apply(self.encode)

        return self.values

所以现在我会得到 2 个列表:

encode = Encoder()
X1 = encode.transform(df1)
X2 = encode.transform(df2)

下一步是:

features = np.concatenate((X1, X1), axis=0)

下一步构建标签:

Y_dog = [[1]] * len(X1)
Y_cat = [[0]] * len(X2)
labels = np.concatenate((Y_dog, Y_cat), axis=0)

以及其他一些操作,然后我将执行 model_selection.train_test_split() 将数据拆分为训练和测试。

如何在 scikit 管道中调用所有这些函数?我发现的示例从已经完成训练/测试拆分的地方开始。

【问题讨论】:

  • 出于好奇,为什么在将两个 DF 连接在一起之后又调用 transform 两次?如果你想使用管道,通常在数据工程完成后使用它,即在训练测试拆分之后。原因是这样的:如果您 fit() 完整数据集上的模型或转换器,它会在训练集和测试集之间创建数据泄漏到模型中
  • @G.Anderson 好的,我明白了,可能这就是我没有找到示例的原因。谢谢

标签: python machine-learning scikit-learn


【解决方案1】:

sklearn.pipeline.Pipeline 的问题在于每一步都需要实现fittransform。因此,例如,如果您知道您将始终需要执行连接步骤,并且您真的很想将其放入 Pipeline (我不会,但这只是我的拙见),您需要使用适当的fittransform 方法创建Concatenator class

类似这样的:

class Encoder(object):
    def fit(self, X, *args, **kwargs):
        return self
    def transform(self, X):
        return X*2

class Concatenator(object):
    def fit(self, X, *args, **kwargs):
        return self
    def transform(self, Xs):
        return np.concatenate(Xs, axis=0)

class MultiEncoder(Encoder):
    def transform(self, Xs):
        return list(map(super().transform, Xs))

pipe = sklearn.pipeline.Pipeline((
    ("encoder", MultiEncoder()),
    ("concatenator", Concatenator())
))

pipe.fit_transform((
    pd.DataFrame([[1,2],[3,4]]), 
    pd.DataFrame([[5,6],[7,8]])
))

【讨论】:

    猜你喜欢
    • 2018-07-07
    • 2013-04-14
    • 2015-08-14
    • 1970-01-01
    • 2021-06-30
    • 1970-01-01
    • 2021-03-17
    • 2023-04-05
    • 2020-11-10
    相关资源
    最近更新 更多