【发布时间】:2020-09-14 18:31:25
【问题描述】:
我正在将 Spark 从版本 2.3.1 升级到 2.4.5。我正在使用 Dataproc 映像 1.4.27-debian9 在 Google Cloud Platform 的 Dataproc 上使用 Spark 2.4.5 重新训练模型。当我使用 Spark 2.4.5 在本地机器上加载 Dataproc 生成的模型来验证模型时。不幸的是,我遇到了以下异常:
20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
20/05/27 08:36:35 INFO HadoopRDD: Input split: file:/Users/.../target/classes/model.ml/stages/1_gbtc_961a6ef213b2/metadata/part-00000:0+657
Exception in thread "main" java.lang.IllegalArgumentException: gbtc_961a6ef213b2 parameter impurity given invalid value variance.
加载模型的代码非常简单:
import org.apache.spark.ml.PipelineModel
object ModelLoad {
def main(args: Array[String]): Unit = {
val modelInputPath = getClass.getResource("/model.ml").getPath
val model = PipelineModel.load(modelInputPath)
}
}
我按照堆栈跟踪检查了1_gbtc_961a6ef213b2/metadata/part-00000 模型元数据文件,发现以下内容:
{
"class": "org.apache.spark.ml.classification.GBTClassificationModel",
"timestamp": 1590593177604,
"sparkVersion": "2.4.5",
"uid": "gbtc_961a6ef213b2",
"paramMap": {
"maxIter": 50
},
"defaultParamMap": {
...
"impurity": "variance",
...
},
"numFeatures": 1,
"numTrees": 50
}
杂质设置为variance,但我的本地spark 2.4.5 预计它是gini。为了进行完整性检查,我在本地 spark 2.4.5 上重新训练了模型。模型元数据文件中的impurity 设置为gini。
所以,我检查了 GBT Javadoc 中的 spark 2.4.5 setImpurity method。它说The impurity setting is ignored for GBT models. Individual trees are built using impurity "Variance."。 Dataproc 使用的 spark 2.4.5 似乎与 Apache Spark 文档一致。但是,我从 Maven Central 使用的 Spark 2.4.5 将 impurity 值设置为 gini。
有人知道为什么 Dataproc 中的 Spark 2.4.5 和 Maven Central 之间会出现这种不一致吗?
我创建了一个简单的训练代码来在本地重现结果:
import java.nio.file.Paths
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
object SimpleModelTraining {
def main(args: Array[String]) {
val currentRelativePath = Paths.get("")
val save_file_location = currentRelativePath.toAbsolutePath.toString
val spark = SparkSession.builder()
.config("spark.driver.host", "127.0.0.1")
.master("local")
.appName("spark-test")
.getOrCreate()
val df: DataFrame = spark.createDataFrame(Seq(
(0, 0),
(1, 0),
(1, 0),
(0, 1),
(0, 1),
(0, 1),
(0, 2),
(0, 2),
(0, 2),
(0, 3),
(0, 3),
(0, 3),
(1, 4),
(1, 4),
(1, 4)
)).toDF("label", "category")
val pipeline: Pipeline = new Pipeline().setStages(Array(
new VectorAssembler().setInputCols(Array("category")).setOutputCol("features"),
new GBTClassifier().setMaxIter(30)
))
val pipelineModel: PipelineModel = pipeline.fit(df)
pipelineModel.write.overwrite().save(s"$save_file_location/test_model.ml")
}
}
谢谢!
【问题讨论】:
标签: scala apache-spark google-cloud-dataproc