【问题标题】:issue in Decision Tree Classifier决策树分类器中的问题
【发布时间】:2018-04-08 22:02:04
【问题描述】:

我正在尝试运行决策树分类器,标签具有双模式,值从 -20 到 +20

import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import java.io.File`

     val dtModelPath = s"file:///home/parv/spark/examples/src/main/scala/org/apache/spark/examples/ml/ dtModel"

     val dtModel= { 
     val dtGridSearch = for (   
     dtImpurity<- Array("entropy", "gini");    
     dtDepth<- Array(3, 5))    
     yield {
     println(s"Training decision tree: impurity $dtImpurity,depth: $dtDepth")
     val dtModel = new DecisionTreeClassifier()
     .setFeaturesCol(idf.getOutputCol)  
     .setLabelCol("value")
     .setImpurity(dtImpurity)         
     .setMaxDepth(dtDepth)     
     .setMaxBins(10)          
     .setSeed(42)          
     .setCacheNodeIds(true)          
     .fit(trainData)
     val dtPrediction = dtModel.transform(testData)      
     val dtAUC = new BinaryClassificationEvaluator().setLabelCol("value").evaluate(dtPrediction)      
     println(s" DT AUC on test data: $dtAUC")      
     ((dtImpurity, dtDepth), dtModel, dtAUC)
     }    
     println(dtGridSearch.sortBy(-_._3).take(5).mkString("\n")) 
     val bestModel = dtGridSearch.sortBy(-_._3).head._2
     bestModel.write.overwrite.save(dtModelPath)
     bestModel
     }

我遇到了错误

raining 决策树: impurity entropy,depth: 3 [Stage 31346:=============> (47 + [阶段 31346:===============>(61+【阶段 31346:=======================>(87+【阶段 31346:=============================>(111+【阶段 31346:===================================> (135 + [阶段 31346:===========================================>(166+【阶段 31346:================================================= > (192 + 18/03/30 01:06:18 WARN 执行器:1 个块锁未被释放 TID = 63510:[rdd_62747_0] 18/03/30 01:06:18 错误执行程序:异常 在阶段 31353.0 (TID 63518) 中的任务 7.0 java.lang.IllegalArgumentException:要求失败:分类器是 给定标签无效的数据集 -6.0。标签必须是整数 范围 [0, 1, ..., 44),其中 numClasses=44。在 scala.Predef$.require(Predef.scala:224)

【问题讨论】:

    标签: java apache-spark apache-spark-ml


    【解决方案1】:

    看来你给分类器一个无效的标签。 上面写着Classifier was given dataset with invalid label -6.0. Labels must be integers in range [0, 1, ..., 44)

    我会检查标签,比如

    df.select($"labels").distinct.show(100)
    df.filter($"labels" < 0).show()
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2017-03-06
      • 2021-09-23
      • 2017-12-21
      • 2018-10-20
      • 1970-01-01
      • 2021-06-17
      • 2011-03-15
      • 2018-04-11
      相关资源
      最近更新 更多