[Question title]: Probability of predictions using Spark LogisticRegressionWithLBFGS for multiclass classification
[Posted]: 2016-07-09 04:10:05
[Question]:

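I am using LogisticRegressionWithLBFGS() to train a model with multiple classes.

The MLlib documentation says that clearThreshold() can only be used when the classification is binary. Is there a way to do something similar for multiclass classification, so that the model outputs the probability of each class for a given input?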

[Comments on the question]:

    Tags: apache-spark pyspark logistic-regression apache-spark-mllib


    [Solution 1]:

    There are two ways to do this. One is to write a method that takes over the job of predictPoint in LogisticRegression.scala:

    import org.apache.spark.mllib.classification.LogisticRegressionModel
    import org.apache.spark.mllib.linalg.{DenseVector, Vector}

    object ClassificationUtility {
      // Returns the predicted class together with an array of per-class scores.
      def predictPoint(dataMatrix: Vector, model: LogisticRegressionModel):
        (Double, Array[Double]) = {
        require(dataMatrix.size == model.numFeatures)
        // Weights are stored as (numClasses - 1) consecutive blocks, one per non-pivot class.
        val dataWithBiasSize: Int = model.weights.size / (model.numClasses - 1)
        val weightsArray: Array[Double] = model.weights match {
          case dv: DenseVector => dv.values
          case _ =>
            throw new IllegalArgumentException(
              s"weights only supports dense vector but got type ${model.weights.getClass}.")
        }
        var bestClass = 0
        var maxMargin = 0.0
        val withBias = dataMatrix.size + 1 == dataWithBiasSize
        // Index 0 (the pivot class) is never written below and stays 0.0.
        val classProbabilities: Array[Double] = new Array[Double](model.numClasses)
        (0 until model.numClasses - 1).foreach { i =>
          var margin = 0.0
          dataMatrix.foreachActive { (index, value) =>
            if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
          }
          // Intercept is required to be added into margin.
          if (withBias) {
            margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
          }
          if (margin > maxMargin) {
            maxMargin = margin
            bestClass = i + 1
          }
          // Sigmoid of this class's margin; the values are not normalized across classes.
          classProbabilities(i + 1) = 1.0 / (1.0 + math.exp(-margin))
        }
        (bestClass.toDouble, classProbabilities)
      }
    }
    

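    Note that it differs only slightly from the original method: it simply computes the logistic function of the margin for the input features. It also defines a few vals and vars that were originally private and declared outside this method. Finally, it stores the scores in an array and returns that array together with the best class. I call my method like this: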

    // Compute raw scores on the test set.
    val predictionAndLabelsAndProbabilities = test
      .map { case LabeledPoint(label, features) =>
        val (prediction, probabilities) = ClassificationUtility
          .predictPoint(features, model)
        (prediction, label, probabilities)
      }
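
The scores returned by predictPoint are independent sigmoids of each margin, so they do not sum to 1 and the pivot class (index 0) is left at 0.0. If a normalized distribution is wanted, one option is a softmax over the margins with the pivot class fixed at margin 0, which is how MLlib's multinomial model is parameterized. The following is a minimal sketch under that assumption; `MultinomialProbabilities` and `fromMargins` are hypothetical names that are not part of Spark.

```scala
object MultinomialProbabilities {
  /** Converts the K-1 margins (relative to pivot class 0) into K probabilities. */
  def fromMargins(margins: Array[Double]): Array[Double] = {
    // The pivot class 0 has an implicit margin of 0.
    val all = 0.0 +: margins
    // Subtract the largest margin before exponentiating to avoid overflow.
    val maxMargin = all.max
    val exps = all.map(m => math.exp(m - maxMargin))
    val total = exps.sum
    exps.map(_ / total)
  }
}
```

To use it with the method above, one would collect each `margin` into an array inside the loop and pass that array to `fromMargins` instead of applying the sigmoid per class.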
    

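    However:

    The Spark contributors appear to be discouraging use of MLlib in favor of ML. The ML logistic regression API does not currently support multiclass classification. I now use OneVsRest, which acts as a wrapper for one-vs-all classification. You can obtain the raw scores by iterating over the models: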

    import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}

    val lr = new LogisticRegression().setFitIntercept(true)
    val ovr = new OneVsRest()
    ovr.setClassifier(lr)
    val ovrModel = ovr.fit(training)
    ovrModel.models.zipWithIndex.foreach {
      case (model: LogisticRegressionModel, i: Int) =>
        model.save(s"model-${model.uid}-$i")
    }
    
    val model0 = LogisticRegressionModel.load("model-logreg_457c82141c06-0")
    val model1 = LogisticRegressionModel.load("model-logreg_457c82141c06-1")
    val model2 = LogisticRegressionModel.load("model-logreg_457c82141c06-2")
    

    Now that you have the individual models, you can obtain probabilities by computing the sigmoid of rawPrediction:

    def sigmoid(x: Double): Double = {
      1.0 / (1.0 + Math.exp(-x))
    }
    
    val newPredictionAndLabels0 = model0.transform(newRescaledData)
      .select("prediction", "rawPrediction")
      .map(row => (row.getDouble(0),
        sigmoid(row.getAs[org.apache.spark.mllib.linalg.DenseVector](1).values(1)) ))
    newPredictionAndLabels0.foreach(println)
    
    val newPredictionAndLabels1 = model1.transform(newRescaledData)
      .select("prediction", "rawPrediction")
      .map(row => (row.getDouble(0),
        sigmoid(row.getAs[org.apache.spark.mllib.linalg.DenseVector](1).values(1)) ))
    newPredictionAndLabels1.foreach(println)
    
    val newPredictionAndLabels2 = model2.transform(newRescaledData)
      .select("prediction", "rawPrediction")
      .map(row => (row.getDouble(0),
        sigmoid(row.getAs[org.apache.spark.mllib.linalg.DenseVector](1).values(1)) ))
    newPredictionAndLabels2.foreach(println)
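
Each binary model yields its own "this class vs. the rest" score, so the three sigmoid values for one row need not sum to 1. The sketch below shows one way the per-model scores for a single row might be combined: take the class with the highest score as the prediction and rescale the scores to sum to 1. The rescaling is a heuristic, not a calibrated probability, and `OneVsRestScores` and `combine` are hypothetical names introduced only for illustration.

```scala
object OneVsRestScores {
  def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

  /**
   * rawMargins(k) is assumed to be values(1) of model k's rawPrediction
   * for one input row. Returns (index of best class, rescaled scores).
   */
  def combine(rawMargins: Array[Double]): (Int, Array[Double]) = {
    val scores = rawMargins.map(sigmoid)
    val total = scores.sum
    // Heuristic rescaling so the per-class scores sum to 1.
    val normalized = scores.map(_ / total)
    val best = scores.indices.maxBy(i => scores(i))
    (best, normalized)
  }
}
```

For example, `OneVsRestScores.combine(Array(-1.0, 2.0, 0.0))` selects class 1, because its margin, and therefore its sigmoid, is the largest.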
    

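    [Discussion]:

    • I tried the OneVsRest solution, but when using .select("prediction","rawPrediction ") I cannot access the rawPrediction column when I run newPredictionAndLabels0.show(). How did you get rawPrediction?
    • @outlier It looks like you have an extra space at the end of "rawPrediction". Try removing it and see if that fixes it.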