【发布时间】:2018-08-01 12:12:38
【问题描述】:
我正在运行以下代码:
def calcClusteringScores(data: RDD[Vector], k: Int) : Double = {
val model = KMeans.train(data=data, k, maxIterations = 1)
data.map(datum => distanceToCentroid(datum, model)).mean()
}
KMeans.train 返回一个KMeansModel(参见:here),它实现了可序列化并且应该是可序列化的。
但是,当我运行 data.map 函数时,我收到一个 object not serializable 异常抱怨模型。有没有办法解决这个问题,我错过了?
更新 1
下面是distanceToCentroid方法,调用距离。计算2个向量之间的欧几里得距离
def distanceToCentroid(datum: Vector, model: KMeansModel) : Double ={
val cluster = model.predict(datum)
val clusterCenter = model.clusterCenters(cluster)
distance(datum, clusterCenter)
}
def distance(a: Vector, b: Vector) : Double ={
val a_arr = a.toArray
val b_arr = b.toArray
val pairs = a_arr.zip(b_arr)
val sumOfSquares = pairs.map(pair => pair._1 - pair._2)
.map(diff => diff * diff)
.sum
sqrt(sumOfSquares)
}
更新 2
我通过将方法体从函数移动到主方法来解决序列化问题。我不再收到序列化错误,但我不知道为什么。有人有什么想法吗?
def testSerialiseModel() ={
val sparkConf = new SparkConf().setAppName("ModelTest").setMaster("local")
val sc = new SparkContext(sparkConf)
val sparkSession = SparkSession.builder().getOrCreate()
val data = sc.parallelize(Array(
Vectors.dense(Array(1.0, 2.0, 3.0)),
Vectors.dense(Array(1.0, 1.8, 2.3)),
Vectors.dense(Array(2.0, 1.5, 3.0))
))
val model = KMeans.train(data=data, 2, maxIterations = 1)
val score = data.map{datum =>
val cluster = model.predict(datum)
val clusterCenter = model.clusterCenters(cluster)
val pairs = datum.toArray.zip(clusterCenter.toArray)
val sumOfSquares = pairs.map(pair => pair._1 - pair._2)
.map(diff => diff * diff)
.sum
sqrt(sumOfSquares)
}.mean()
println(s"clustering score: ${score}")
}
【问题讨论】:
-
什么是
distanceToCentroid?因此我无法重现您的错误,并且在KMeans的文档中找不到它。仅供参考,当我在您的定义中没有data.map部分的情况下运行您的代码时,我不会收到任何错误,因此它不会是与KMeans模型相关的问题。也许您使用 RDD 定义了方法distanceToCentroid,并且该方法未序列化。可能这就是您收到错误的原因 -
在使用
clusterCenters()调用distanceToCentroid之前获取质心作为向量怎么样?或者,如果您想要整个数据集的成本,您可以使用computeCost方法。 -
@user322778 谢谢。我已更新问题以显示 distanceToCentroid 方法。
-
@Shaido 谢谢,model.clusterCenters(cluster) 已经返回一个向量
-
你在哪里运行你的代码?我运行了您的代码,包括
distanceToCentroidpart,没有任何错误(我使用的是 Cloudera 虚拟机 5.8.0)
标签: scala apache-spark apache-spark-mllib