【问题标题】:One-hot encoding multiple variables with Spark 2.1.1使用 Spark 2.1.1 对多个变量进行 One-hot 编码
【发布时间】:2020-06-23 05:53:31
【问题描述】:

我需要使用 Spark 2.1.1 并有一个简单的 ML 用例,我在其中拟合逻辑回归以执行基于连续变量和分类变量的分类。

我会自动检测分类变量并在 ML 管道中为它们编制索引。但是,当我尝试对每个变量(下面代码中的 oneHotEncodersStages 值)应用单热编码时,创建管道时会导致错误:

错误:(48, 118) 类型不匹配;找到:数组[java.io.Serializable] 必需: Array[_ <: org.apache.spark.ml.pipelinestage java.io.serializable>: org.apache.spark.ml.PipelineStage,但类 数组在类型 T 中是不变的。您可能希望研究通配符 类型如_ &gt;: org.apache.spark.ml.PipelineStage。 (SLS 3.2.10)
val pipeline = new Pipeline().setStages(stringIndexerStages :+ oneHotEncodersStages :+ 索引器 :+ 汇编器 :+ lr :+ indexToLabel)

我没有找到解决此错误的方法...有什么提示吗?下面是一个最小的工作示例

  import spark.implicits._
  val df = Seq(
    ("automatic","Honda",200,"Cheap"),
    ("semi-automatic","Ford",240,"Expensive")
  ).toDF("cat_type","cat_brand","Speed","label")

  def onlyFeatureCols(c: String): Boolean = !(c matches "id|label") // Function to select only feature columns (omit id and label)
  def isCateg(c: String): Boolean = c.startsWith("cat")
  def categNewCol(c: String): String = if (isCateg(c)) s"idx_${c}" else c
  def isIdx(c: String): Boolean = c.startsWith("idx")
  def oneHotNewCol(c: String): String = if (isIdx(c)) s"vec_${c}" else c

  val featuresNames = df.columns
    .filter(onlyFeatureCols)
    .map(categNewCol)

  val stringIndexerStages = df.columns.filter(isCateg)
    .map(c => new StringIndexer()
      .setInputCol(c)
      .setOutputCol(categNewCol(c))
      .fit(df.select(c))
    )

  val oneHotEncodersStages = df.columns.filter(isIdx)
    .map(c => new OneHotEncoder()
      .setInputCol(c)
      .setOutputCol(oneHotNewCol(c)))

  val indexer = new StringIndexer().setInputCol("label").setOutputCol("labels").fit(df)
  val indexToLabel = new IndexToString().setInputCol("prediction").setOutputCol("predicted_label").setLabels(indexer.labels)
  val assembler = new VectorAssembler().setInputCols(featuresNames).setOutputCol("features")
  val lr = new LogisticRegression().setFeaturesCol("features").setLabelCol("labels")

  val pipeline = new Pipeline().setStages(stringIndexerStages :+ oneHotEncodersStages ++ indexer :+  assembler :+ lr :+ indexToLabel)

【问题讨论】:

    标签: scala apache-spark apache-spark-ml


    【解决方案1】:

    stringIndexerStagesoneHotEncodersStages 是数组。 stringIndexerStages :+ oneHotEncodersStages - 创建新数组,其中第二个数组用作单个元素。使用“++”代替“:+”:

    val pipeline = new Pipeline().setStages(stringIndexerStages ++ oneHotEncodersStages :+ indexer :+  assembler :+ lr :+ indexToLabel)
    

    【讨论】:

      猜你喜欢
      • 2020-03-18
      • 1970-01-01
      • 2018-02-23
      • 2018-05-22
      • 2018-10-31
      • 2019-09-23
      • 2021-04-22
      • 2023-04-08
      • 2019-09-13
      相关资源
      最近更新 更多