【问题标题】:Adjust Intercept of Spark DataFrame API Logistic Regression Model调整 Spark DataFrame API Logistic 回归模型的拦截
【发布时间】:2016-12-06 03:25:16
【问题描述】:

我正在 Spark 中训练逻辑回归。但是,由于我的训练数据的特殊性,我需要在之后手动调整模型,即更改截距。

使用 RDD api 很容易做到这一点 - 只需实例化一个新的 LogisticRegressionModel:

val intercept = model.intercept() + adjustment
val model = new LogisticRegressionModel(model.weights(), intercept)

但是,DataFrame API 中的 LogisticRegressionModel 构造函数是私有的。如何手动调整模型?

【问题讨论】:

    标签: apache-spark apache-spark-mllib


    【解决方案1】:

    今天下午我遇到了同样的问题,我处于测试模式,无论如何都试图让它发生,所以我不在乎它有多脏:从你的模型中获取系数,获取截距,调整它,然后使用他们在 Spark 中使用的 code 手动进行预测(查找 BLAS.dotmarginscore)。在某些时候他们使用BLAS.dot,好吧BLAS 是私有的。再次做同样的事情,检索dot的代码,处理SparseVector/DenseVector就可以了。很脏,但它可以工作。

    【讨论】:

    • 或者你可以分叉 Spark 并添加你自己的复制函数,并带有一个可以改变截距的参数。这也有效。可能更漂亮。
    • 显然,通过反射可以更轻松地完成(仍然很脏):val modelClass = classOf[LogisticRegressionModel]; val const = modelClass.getDeclaredConstructor(classOf[String], classOf[Vector], classOf[Double]); val intercept = trained.intercept - adjustment; val newModel = const.newInstance(trained.uid, trained.coefficients, intercept:java.lang.Double))