【问题标题】:What is the input format of org.apache.spark.ml.classification.LogisticRegression fit()?org.apache.spark.ml.classification.LogisticRegression fit() 的输入格式是什么?
【发布时间】:2016-08-01 12:12:40
【问题描述】:

this 训练 LogisticRegression 模型的示例中,他们使用 RDD[LabeledPoint] 作为 fit() 方法的输入,但他们写了 "// 我们使用 LabeledPoint,这是一个案例类。Spark SQL 可以转换 RDD案例类别 // 进入 SchemaRDD,它使用案例类元数据来推断模式。"

这种转换发生在哪里?当我尝试这段代码时:

val sqlContext = new SQLContext(sc)
import sqlContext._
val model = lr.fit(training);

,如果训练是 RDD[LabeledPoint] 类型,它会给出一个编译错误,说明 fit 需要一个数据框。当我将 RDD 转换为数据框时,我得到了这个异常:

An exception occured while executing the Java class. null: InvocationTargetException: requirement failed: Column features must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually StructType(StructField(label,DoubleType,false), StructField(features,org.apache.spark.mllib.linalg.VectorUDT@f71b0bce,true))

但这让我很困惑。为什么它会期望一个向量?它还需要标签。所以我想知道正确的格式是什么?

我使用 ML LogisticRegression 而不是 Mllib LogisticRegressionWithLBFGS 的原因是因为我想要一个 elasticNet 实现。

【问题讨论】:

  • 如果您使用的是 spark 2.0,最好在应用 mllib 中的函数之前将所有 RDD 转换为 Dataframes - 但您的错误并非如此,我相信输入需要是 mllib 向量- 你需要做 Vectors.dense(Array[Doubles]) 其中双打是你的数据点。
  • 但是标签呢?您是否假设标签只是向量的第一列?

标签: scala apache-spark


【解决方案1】:

异常表示 DataFrame 需要以下结构:

StructType(StructField(label,DoubleType,false), 
StructField(features,org.apache.spark.mllib.linalg.VectorUDT@f71b0bce,true))

因此,从(标签、特征)元组列表中准备训练数据,如下所示:

val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

【讨论】:

  • 例外不是说反了吗?我的数据已经是那种格式并且应该是 Vector 类型的?
猜你喜欢
  • 1970-01-01
  • 2021-12-25
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2012-04-08
  • 2018-09-09
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多