【问题标题】:how to extend (or proxy) a scala class with private constructor如何使用私有构造函数扩展(或代理)scala类
【发布时间】:2016-10-22 19:13:15
【问题描述】:

我正在尝试扩展或代理 org.apache.spark.ml.clustering.KMeans 类,以便授权 K=1。

class K1Means extends Estimator{

    final val kmeans = new KMeans()
    val k = 1

    override def setK(value:Int) {
        if(value >1){
            this.kmeans.setK(value)
        }
    }

    override def fit(dataset: DataFrame): KMeansModel = { 
        if(this.k == 1){
            /* super specific to my case */
            val avg_sample = Vectors.zeros(
                dataset
                .select("scaledFeatures")
                .take(1)(0)(0)  // first row
                .asInstanceOf[DenseVector]  // was of type Any
                .size
            ) // with the scaling the average value of each column is 0
            var centers_local = Array(avg_sample)
            return new KMeansModel(centers_local)
        }
        else{
            return this.kmeans.fit(dataset)
        }
    }
// every method then calls this.kmeans.method()
}

我已经尝试过了,但是 new KMeansModel(centers_local) 没有被授权,因为 KMeansModel 有一个私有构造函数。 这是错误消息:

constructor KMeansModel in class KMeansModel cannot be accessed in class K1Means

我也尝试扩展 KMeansModel,所以我可以创建自己的并返回它:

class K1MeansModel(centers: Array[DenseVector]) extends KMeansModel{}

但它也失败了:constructor KMeansModel in class KMeansModel cannot be accessed in class K1MeansModel

【问题讨论】:

标签: scala inheritance private proxy-classes


【解决方案1】:

这里有几个问题,首先是 KMeansModel 是私有的: https://github.com/apache/spark/blob/4f83ca1059a3b580fca3f006974ff5ac4d5212a1/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala#L102

为什么会出现这个问题?您完全可以按照您建议的方式编写自己的代理,但是为了覆盖“fit”方法,该函数返回的数据类型需要是 KMeansModel 或兼容的(比如说“K1MeansModel”),如下所示:

class K1MeansModel extends KMeansModel{
    // ...
}

class K1Means extends KMeans{

    final val kmeans = new KMeans()
    // ...

    override def fit(dataset: DataFrame): KMeansModel = { 
        if(this.k == 1){
            // ...
            return new K1MeansModel(centers_local)
        }
        else{
            return this.kmeans.fit(dataset)
        }
    }
}

但是,是的,因为 KMeansModel 是私有的,所以这是不可能的。所以你可能会想“为什么不重新实现它呢?”。实际上,您可以从 GitHub 复制并粘贴 KMeansModel 的整个代码。

KMeansModel 的定义如下:

class KMeansModel (
        override val uid: String, 
        private val parentModel: MLlibKMeansModel) 
    extends Model[KMeansModel] with KMeansParams { }

但是是的,因为 KMeansParams 是私有的,所以这是不可能的。所以你可能会想“为什么不重新实现它呢?”。实际上,您可以从 GitHub 复制并粘贴 KMeansParams 的整个代码。

KMeansParams 的定义如下:

trait K1MeansParams 
    extends Params 
        with HasMaxIter 
        with HasFeaturesCol 
        with HasSeed 
        with HasPredictionCol 
        with HasTol { }

但是,是的,因为 HasMaxIter、HasFeaturesCol、HasSeed、HasPredictionCol、HaTol 都是私有的,这是不可能的。 ...你明白了。


TL;DR 是的,您可以在项目中重新实现(复制和粘贴)大量 spark 类,只是为了覆盖 KMeans。我计算了至少 7 个需要复制和粘贴的课程。对我来说这感觉很糟糕。 我建议将代码直接添加到 Apache Spark。 fork Spark GitHub repo,将 K=1 的代码直接添加到 ml.KMeans 类中即可。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2020-10-21
    • 1970-01-01
    • 2012-01-09
    • 2012-11-18
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多