【问题标题】:How to load logistic regression model?如何加载逻辑回归模型?
【发布时间】:2018-05-21 07:47:09
【问题描述】:

我想在 Java 中使用 Apache Spark 训练逻辑回归模型。作为第一步,我想只训练一次模型并保存模型参数(截距和系数)。随后使用保存的模型参数在稍后的时间点进行评分。我可以使用以下代码将模型保存在镶木地板文件中

LogisticRegressionModel trainedLRModel = logReg.fit(data);
trainedLRModel.write().overwrite().save("mypath");

当我加载模型进行评分时,出现以下错误。

LogisticRegression lr = new LogisticRegression();
lr.load("//saved_model_path");

Exception in thread "main" java.lang.NoSuchMethodException: org.apache.spark.ml.classification.LogisticRegressionModel.<init>(java.lang.String)
    at java.lang.Class.getConstructor0(Class.java:3082)
    at java.lang.Class.getConstructor(Class.java:1825)
    at org.apache.spark.ml.util.DefaultParamsReader.load(ReadWrite.scala:325)
    at org.apache.spark.ml.util.MLReadable$class.load(ReadWrite.scala:215)
    at org.apache.spark.ml.classification.LogisticRegression$.load(LogisticRegression.scala:672)
    at org.apache.spark.ml.classification.LogisticRegression.load(LogisticRegression.scala)

有没有办法训练和保存模型,然后再评估(得分)?我在 Java 中使用 Spark ML 2.1.0。

【问题讨论】:

    标签: apache-spark apache-spark-ml


    【解决方案1】:

    我在使用 pyspark 2.1.1 时遇到同样的问题,当我从 LogisticRegression 更改为 LogisticRegressionModel 时,一切正常。

    LogisticRegression.load("/model/path") # not works 
    
    LogisticRegressionModel.load("/model/path") # works well
    

    【讨论】:

    • pysparkStringIndexerStringIndexerModel 有同样的问题
    【解决方案2】:

    TL;DR使用LogisticRegressionModel.load

    load(path: String): LogisticRegressionModel 从输入路径读取一个ML实例,read.load(path)的快捷方式。


    事实上,从 Spark 2.0.0 开始,推荐使用 Spark MLlib 的方法,包括。 LogisticRegression 估计器,正在使用全新闪亮的Pipeline API

    import org.apache.spark.ml.classification._
    val lr = new LogisticRegression()
    
    import org.apache.spark.ml.feature._
    val tok = new Tokenizer().setInputCol("body")
    val hashTF = new HashingTF().setInputCol(tok.getOutputCol).setOutputCol("features")
    
    import org.apache.spark.ml._
    val pipeline = new Pipeline().setStages(Array(tok, hashTF, lr))
    
    // training dataset
    val emails = Seq(("hello world", 1)).toDF("body", "label")
    
    val model = pipeline.fit(emails)
    
    model.write.overwrite.save("mypath")
    val loadedModel = PipelineModel.load("mypath")
    

    【讨论】:

      猜你喜欢
      • 2019-01-06
      • 2022-10-04
      • 2021-09-07
      • 1970-01-01
      • 2019-10-16
      • 1970-01-01
      • 2020-12-22
      • 1970-01-01
      • 2018-01-26
      相关资源
      最近更新 更多