【问题标题】:ctree() - How to get the list of splitting conditions for each terminal node?(ctree() - 如何获取每个终端节点的拆分条件列表?)
【发布时间】:2014-02-21 23:19:03
【问题描述】:

我有来自 ctree()(party 包)的输出,如下所示。如何获取每个终端节点的拆分条件列表,例如 sns <= 0, dta <= 1;sns <= 0, dta > 1 等?

1) sns <= 0; criterion = 1, statistic = 14655.021
  2) dta <= 1; criterion = 1, statistic = 3286.389
   3)*  weights = 153682 
  2) dta > 1
   4)*  weights = 289415 
1) sns > 0
  5) dta <= 2; criterion = 1, statistic = 1882.439
   6)*  weights = 245457 
  5) dta > 2
   7) dta <= 6; criterion = 1, statistic = 1170.813
     8)*  weights = 328582 
   7) dta > 6

谢谢

【问题讨论】:

    标签: r decision-tree party


    【解决方案1】:

    我把 CtreePathFunc 函数重写成了更接近 Hadley-verse 风格(我认为更易于理解)的形式,并且它还能处理分类变量。

    library(magrittr)
    # Decode the split point stored on a split description.
    #
    # nodeSplit: an object exposing a `splitpoint` element.
    # Returns the level label(s) selected by `splitpoint` when it carries a
    # "levels" attribute; otherwise the split point coerced to numeric.
    readSplitter <- function(nodeSplit) {
      cutPoint <- nodeSplit$splitpoint
      carriesLevels <- "levels" %in% names(attributes(cutPoint))
      if (!carriesLevels) {
        return(as.numeric(cutPoint))
      }
      # The split point is used as an index into its own level labels.
      attr(cutPoint, "levels")[cutPoint]
    }
    
    # Positions of the observations whose weight is non-zero in the node that
    # follows step `pathNumber` on `path`.
    #
    # ct:           tree object accepted by nodes().
    # path:         vector of node ids leading towards `terminalNode`.
    # terminalNode: id of the node the path ends in.
    # pathNumber:   position within `path` of the split being described.
    # Returns an integer vector of indices into the node's `weights` element.
    hasWeigths <- function(ct, path, terminalNode, pathNumber) {
      # After the last split on the path the next node is the terminal node
      # itself; otherwise it is the following entry of the path.
      followingNode <- if (pathNumber == length(path)) {
        terminalNode
      } else {
        path[pathNumber + 1]
      }
      nodeInfo <- nodes(ct, followingNode)[[1]]
      which(as.logical(nodeInfo$weights))
    }
    
    # Describe, as text, the condition imposed by split number `pathNumber`
    # on the way to `terminalNode`.
    #
    # The rows reaching the next node are looked up with hasWeigths(), and the
    # fifth element of the splitting node (its split description) is handed to
    # the matching buildDataFilter() method together with the data `dts`.
    dataFilter <- function(ct, dts, path, terminalNode, pathNumber) {
      rowsInNextNode <- hasWeigths(ct, path, terminalNode, pathNumber)
      splitDescription <- nodes(ct, path[pathNumber])[[1]][[5]]
      buildDataFilter(splitDescription, dts, rowsInNextNode)
    }
    
    # S3 generic: turn a split description into a textual condition.
    # Dispatches on the class of `nodeSplit`; all further arguments are passed
    # on to the selected method.
    buildDataFilter <- function(nodeSplit, ...) {
      UseMethod("buildDataFilter")
    }
    
    # Method for nominal (unordered) splits.
    #
    # nodeSplit:    split description exposing `variableName`.
    # dts:          data frame containing that variable.
    # whichWeights: row indices of `dts` that fall into the child node.
    # Returns a string of the form "<variable> == {v1, v2, ...}" listing the
    # distinct values of the variable observed in those rows.
    buildDataFilter.nominalSplit <-
      function(nodeSplit, dts, whichWeights) {
        varName <- nodeSplit$variableName
        observedValues <- unique(dts[whichWeights, varName])
        valueSet <- paste0("{", paste(observedValues, collapse = ", "), "}")
        paste(varName, "==", valueSet)
      }
    
    # Method for ordered (numeric or ordered-factor) splits.
    #
    # nodeSplit:    split description exposing `variableName` and a split point
    #               readable by readSplitter().
    # dts:          data frame containing that variable.
    # whichWeights: row indices of `dts` that fall into the child node.
    # Returns "<variable> <= <split>" when every observed value in those rows is
    # at or below the split point, otherwise "<variable> > <split>".
    buildDataFilter.orderedSplit <-
      function(nodeSplit, dts, whichWeights) {
        varName <- nodeSplit$variableName
        splitter <- readSplitter(nodeSplit)

        # Rows with a missing split variable say nothing about the direction of
        # the branch, so they are ignored; without na.rm the test would become
        # NA and the condition would be rendered as "<variable> NA <split>".
        allAtOrBelow <- all(dts[whichWeights, varName] <= splitter, na.rm = TRUE)
        direction <- if (allAtOrBelow) "<=" else ">"
        paste(varName, direction, splitter)
      }
    
    # Build a table with one row per terminal node of a fitted tree, giving the
    # conjunction of split conditions that leads to that node.
    #
    # ct:  tree object accepted by nodes() and where().
    # dts: data frame handed to dataFilter() to decide the direction of each
    #      split (presumably the data the tree was fitted on -- confirm).
    # Returns a data.frame with columns `Node` (terminal node id) and `Path`
    # (conditions joined by " & "; an empty string when there is no split above
    # the node).
    readTerminalNodePaths <- function (ct, dts) {

      nodeWeights <- function(Node) nodes(ct, Node)[[1]]$weights
      sgmnts <- ct %>% where %>% unique
      nodesFirstTreeWeightIsOne <- function(node) nodes(ct, node)[[1]][2][[1]] == 1

      # Candidate ancestors: ids below the terminal node that are not terminal
      # nodes themselves. seq_len() rather than 1:(Node - 1), so that Node == 1
      # yields no candidates instead of c(1, 0).
      innerNodes <- function(Node) {
        setdiff(seq_len(Node - 1), sgmnts[sgmnts < Node])
      }

      # Keep the candidates that share at least one weighted observation with
      # the terminal node.
      pathForTerminalNode <- function(terminalNode){
        terminalWeights <- nodeWeights(terminalNode)
        Filter(
          function(innerNode) {
            any(terminalWeights & nodesFirstTreeWeightIsOne(innerNode))
          },
          innerNodes(terminalNode)
        )
      }

      # Find the split criteria
      sgmnts %>%
        lapply(function(terminalNode){

          path <- pathForTerminalNode(terminalNode)

          # seq_along() so that an empty path produces no conditions rather
          # than iterating over c(1, 0).
          conditions <- path %>%
            seq_along %>%
            lapply(function(nodeNumber){
              dataFilter(ct, dts, path, terminalNode, nodeNumber)
            }) %>%
            unlist

          data.frame(Node = terminalNode, Path = paste(conditions, collapse = " & "))

        }) %>%
        Reduce(f = rbind)
    }
    

    测试

    # Divide the leading part of a numeric vector by a constant.
    #
    # vctr:       numeric vector.
    # divideBy:   divisor applied to the leading elements.
    # proportion: share of the vector (rounded to a whole number of elements)
    #             that is divided; defaults to the first half.
    # Returns `vctr` with its first round(length(vctr) * proportion) elements
    # divided by `divideBy`; the remaining elements are unchanged.
    shiftFirstPart <- function(vctr, divideBy, proportion = .5) {
      # seq_len() yields an empty index when the leading part has zero
      # elements; seq(0) would give c(1, 0) and wrongly touch element 1.
      firstPart <- seq_len(round(length(vctr) * proportion))
      vctr[firstPart] <- vctr[firstPart] / divideBy
      vctr
    }
    # Simulated data: a binary response, one ordered factor, one unordered
    # factor and one numeric predictor, each skewed by shiftFirstPart().
    set.seed(11)
    n <- 13000
    # cut()'s argument is spelled `include.lowest`; the previous spelling
    # `include_lowest` matched nothing and was silently absorbed by `...`.
    gdt <-
      data.frame( is_buyer = runif(n) %>% shiftFirstPart(1.5) %>% round %>% factor(labels = c("no", "yes"))
                 ,age = runif(n) %>% shiftFirstPart(1.5) %>%
                   cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE, ordered_result = TRUE, labels = c("low", "mid", "high"))
                 ,city = runif(n) %>% shiftFirstPart(1.5) %>%
                   cut(breaks = c(0, .3, .6, 1), include.lowest = TRUE, labels = c("Chigaco", "Boston", "Memphis"))
                 ,point = runif(n) %>% shiftFirstPart(1.2)
                 )

    # Fit the tree on all predictors and list the path to every terminal node.
    gct <- ctree( is_buyer ~ ., data = gdt)
    readTerminalNodePaths(gct, gdt)
    

    【讨论】:

    • 这是一个非常酷的函数,我不确定自己是否完全理解它的工作原理,不过有没有一种简单的方法可以修改它,以获取所有节点(而不仅仅是终端节点)的路径?
    • 抱歉这个愚蠢的问题,但是哪个库包含函数'where'?
    • @Manuel Chirouze, party.
    【解决方案2】:

    如果您使用ctree() 的新推荐partykit 实现而不是旧的party 包,那么您可以使用函数.list.rules.party()。这还没有正式导出,但可以用来提取所需的信息。

    library("partykit")
    # Keep only the rows of `airquality` with an observed Ozone value.
    airq <- airquality[!is.na(airquality$Ozone), ]
    # Fit the tree, then list the rule leading to each terminal node using the
    # unexported partykit helper.
    ct <- ctree(Ozone ~ ., data = airq)
    partykit:::.list.rules.party(ct)
    ##                                      3                                      5 
    ##             "Temp <= 82 & Wind <= 6.9" "Temp <= 82 & Wind > 6.9 & Temp <= 77" 
    ##                                      6                                      8 
    ##  "Temp <= 82 & Wind > 6.9 & Temp > 77"             "Temp > 82 & Wind <= 10.3" 
    ##                                      9 
    ##              "Temp > 82 & Wind > 10.3" 
    

    【讨论】:

      【解决方案3】:

      由于我需要用这个函数处理分类数据,我用下面的函数大致回答了 @JoãoDaniel 的问题(我只测试过分类预测变量):

      # Strip leading and trailing whitespace from each string.
      # http://stackoverflow.com/questions/2261079/how-to-trim-leading-and-trailing-whitespace-in-r
      trim <- function (x) gsub("^\\s+|\\s+$", "", x)
      # First whitespace-delimited token of a rule, e.g. "Temp" in "Temp <= { 82 }".
      getVariable <- function (x) sub("(.*?)[[:space:]].*", "\\1", x)
      # Second whitespace-delimited token of a rule, e.g. "<=" in "Temp <= { 82 }".
      getSimbolo <- function (x) sub("(.*?)[[:space:]](.*?)[[:space:]].*", "\\2", x)

      # Simplify a path by keeping, for every (variable, operator) pair, only the
      # last rule that uses it.
      #
      # elemento: a single path given as rules separated by ";", for example
      #           "Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 }".
      # Returns the surviving rules joined by "; ", ordered by the first
      # appearance of their (variable, operator) pair; "" for an empty path.
      getReglaFinal <- function(elemento) {
          reglas <- trim(strsplit(as.character(elemento), ";")[[1]])

          # Key identifying the kind of cut: variable name followed by operator.
          tipo_corte <- paste0(getVariable(reglas), getSimbolo(reglas))

          ReglaFinal <- ""
          # Iterating over the values themselves (not 1:length(...)) makes an
          # empty path fall through and return "".
          for (corte in unique(tipo_corte)) {
              ultima <- reglas[max(which(tipo_corte == corte))]
              if (ReglaFinal == "") {
                  ReglaFinal <- ultima
              } else {
                  ReglaFinal <- paste(ReglaFinal, ultima, sep = "; ")
              }
          }
          ReglaFinal
      }#getReglaFinal
      
      # Describe every terminal node of a fitted tree by the split rules leading
      # to it, and render those rules as a SQL CASE expression.
      #
      # ct: tree object accepted by where(), nodes(), weights() and Predict()
      #     (presumably a party::ctree fit -- confirm).
      #
      # Returns a list with three elements:
      #   PreTable - data frame with Node, Path (all rules on the path), Means,
      #              Counts and TruePath (Path simplified by getReglaFinal()),
      #              ordered by Means.
      #   Table    - the same rows with Path replaced by TruePath and the
      #              TruePath column removed.
      #   SQL      - a single string " CASE  WHEN ... THEN 'Nodo <id>' ... END ".
      CtreePathFuncAllCat <- function (ct) {
      
        ResulTable <- data.frame(Node = character(), Path = character())
      
        # One iteration per distinct value of where(ct); the comments below
        # treat these values as terminal node ids.
        for(Node in unique(where(ct))){
      
          # Taking all possible non-terminal nodes that are smaller than the selected terminal node
          # NOTE(review): 1:(Node - 1) counts down to c(1, 0) when Node is 1.
          NonTerminalNodes <- setdiff(1:(Node - 1), unique(where(ct))[unique(where(ct)) < Node])
      
          # Getting the weights for that node
          NodeWeights <- nodes(ct, Node)[[1]]$weights
      
          # Finding the path: keep every candidate node whose second element
          # (presumably its weight vector -- TODO confirm) equals 1 for at least
          # one observation that also has a non-zero weight in the terminal node.
          Path <- NULL
          for (i in NonTerminalNodes){
              if(any(NodeWeights & nodes(ct, i)[[1]][2][[1]] == 1)) Path <- append(Path, i)
          }
      
          # Finding the splitting criteria for that path
          Path2 <- SB <- NULL
      
          # Variable name and split mask/value recorded for each step of the
          # path, so later splits on the same variable can be combined with
          # earlier ones.
          variablesNombres <- array()
          variablesPuntos <- list()
      
          # NOTE(review): 1:length(Path) iterates over c(1, 0) if Path is empty.
          for(i in 1:length(Path)){
              n <- nodes(ct, Path[i])[[1]]
      
              # The node reached after this split: the next node on the path,
              # or the terminal node itself after the last split.
              if(i == length(Path)) {
                  nextNodeID = Node
              } else {
                  nextNodeID = Path[i+1]
              }       
      
              # n[[5]] is used as the split description: split point, variable
              # name and, when present, the level labels attached to the split
              # point.
              vec_puntos  = as.vector(n[[5]]$splitpoint)
              vec_nombre  = n[[5]]$variableName
              vec_niveles = attr(n[[5]]$splitpoint,"levels")
      
              # index stays 0 for splits handled as level sets and is overwritten
              # with the split point for the two cases below.
              index = 0
      
              # Split point has levels but a different length than the level
              # vector: turn it into a logical mask over the levels, TRUE at the
              # position(s) given by the split point.
              if((length(vec_puntos)!=length(vec_niveles)) && (length(vec_niveles)!=0) ){
                  index = vec_puntos
                  vec_puntos = vector(length=length(vec_niveles))
                  vec_puntos[index] = TRUE
              }
      
              # No levels at all: keep the raw split point.
              if(length(vec_niveles)==0){
                  index = vec_puntos
                  vec_puntos = n[[5]]$splitpoint
              }
      
              # NOTE(review): a split point that is exactly 0 also leaves
              # index == 0 and would take the level-set branch -- confirm whether
              # that can occur for numeric variables.
              if(index==0){
                  # Level-set split: the mask is negated when the path continues
                  # to the right child; !! only coerces the mask to logical.
                  if(nextNodeID==n$right$nodeID){
                      vec_puntos = !vec_puntos
                  }else{
                      vec_puntos = !!vec_puntos
                  }
                  # Intersect with the masks of earlier splits on the same
                  # variable (element-wise product, then back to logical).
                  if(i != 1) {
                      for(j in 1:(length(Path)-1)){
                          if(length(variablesNombres)>=j){
                              if( variablesNombres[j]==vec_nombre){
                                  vec_puntos = vec_puntos*variablesPuntos[[j]]
                              }
                          }
                      }
                      vec_puntos = vec_puntos==1
                  }   
                  SB = "="
              }else{
                  # Threshold split: the operator depends on which child the path
                  # continues to.
                  if(nextNodeID==n$right$nodeID){
                      SB = ">"
                  }else{
                      SB = "<="
                  }
      
              }
      
              variablesPuntos[[i]] = vec_puntos       
              variablesNombres[i] = vec_nombre
      
              # Text shown between the braces: the split value itself, or the
              # selected level labels joined by ", ".
              if(length(vec_niveles)==0){
                  descripcion = vec_puntos
              }else{
                  descripcion = paste(vec_niveles[vec_puntos],collapse=", ")
              }
              # Append "<variable> <operator> { <description> }" to the path,
              # rules separated by "; ".
              Path2 <- paste(c(Path2, paste(c(variablesNombres[i],SB,"{",descripcion, "}"),collapse=" ")
                              ),
                             collapse = "; ")
          }
      
          # Output
          ResulTable <- rbind(ResulTable, cbind(Node = Node, Path = Path2))
        }
      
          # Per-node counts and means.
          # NOTE(review): the rows of Counts and Means are obtained with unique()
          # and bound to ResulTable by position; this appears to assume that they
          # come out in the same order as unique(where(ct)) and that different
          # nodes never share a prediction or a (node, weight) pair -- confirm.
          we = weights(ct)
          c0 = as.matrix(where(ct))
          c3 = sapply(we, function(w) sum(w))
          c3 = as.matrix(unique(cbind(c0,c3)))
          Counts = as.matrix(c3[,2])
          c2 = drop(Predict(ct))
          Means = as.matrix(unique(c2))
      
          ResulTable = data.frame(ResulTable,Means,Counts)
          ResulTable  = ResulTable[ order(ResulTable$Means) ,]
      
          # Simplified path: for each variable/operator pair only the last rule
          # is kept (see getReglaFinal()).
          ResulTable$TruePath =  apply(as.data.frame(ResulTable$Path),1, getReglaFinal)
      
          ResulTable2 = ResulTable
      
          # Rewrite the rule syntax into SQL, innermost gsub first: ";" becomes
          # " AND ", "{ " becomes "('", " }" becomes "')", ", " becomes "','",
          # and finally quotes around purely numeric values are removed.
          ResulTable2$SQL <- paste("WHEN ",gsub("\\'([-+]?([0-9]*\\.[0-9]+|[0-9]+))\\'", "\\1",gsub("\\, ", "','", gsub(" \\}", "')", gsub("\\{ ", "('", gsub("\\;", " AND ", ResulTable2$TruePath)))))," THEN ")
      
          # Append the node label: "<WHEN ... THEN >'Nodo <Node>".
          cols <- c( 'SQL' , 'Node' )
          ResulTable2$SQL <- apply(  ResulTable2[ , cols ] ,1 , paste , collapse = "'Nodo " )
      
          # Close the quote, drop the blank that paste() put before each quote,
          # then restore the blank after THEN.
          ResulTable2$SQL <- gsub("THEN'", "THEN '", gsub(" '", "'",  paste(ResulTable2$SQL,"'")))
      
          ResultadoFinal = list()
      
          ResultadoFinal$PreTable = ResulTable
          ResultadoFinal$Table = ResulTable
          ResultadoFinal$Table$Path = ResultadoFinal$Table$TruePath
          ResultadoFinal$Table$TruePath = NULL
          ResultadoFinal$SQL = paste(" CASE ",paste(ResulTable2$SQL,sep="",collapse=" ")," END ",collapse="")
      
          return(ResultadoFinal)
      }#CtreePathFuncAllCat
      

      这是一个测试:

      library(party)
      # Example 1: ordered factors as predictors.
      # NOTE(review): `mammoexp` is not created in this snippet; presumably it
      # is a dataset made available by an attached package -- confirm before
      # running.
      TreeModel1 <- ctree(PB ~ ME + SYMPT + HIST + BSE + DECT, data = mammoexp)
      Result2 <- CtreePathFuncAllCat(TreeModel1)
      Result2
      ##$PreTable
      ##  Node                                                Path    Means Counts
      ##3    7    DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
      ##2    6   DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
      ##1    4  DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
      ##4    3 DECT <= { Somewhat likely }; DECT <= { Not likely } 9.833333     18
      ##                                          TruePath
      ##3   DECT > { Somewhat likely }; SYMPT > { Disagree }
      ##2  DECT > { Somewhat likely }; SYMPT <= { Disagree }
      ##1 DECT <= { Somewhat likely }; DECT > { Not likely }
      ##4                             DECT <= { Not likely }
      ##
      ##$Table
      ##  Node                                               Path    Means Counts
      ##3    7   DECT > { Somewhat likely }; SYMPT > { Disagree } 6.526316    114
      ##2    6  DECT > { Somewhat likely }; SYMPT <= { Disagree } 7.640000    175
      ##1    4 DECT <= { Somewhat likely }; DECT > { Not likely } 8.161905    105
      ##4    3                             DECT <= { Not likely } 9.833333     18
      ##
      ##$SQL
      ##[1] " CASE  WHEN  DECT > ('Somewhat likely') AND  SYMPT > ('Disagree')  THEN 'Nodo 7' WHEN  DECT > ('Somewhat likely') AND  SYMPT <= ('Disagree')  THEN 'Nodo 6' WHEN  DECT <= ('Somewhat likely') AND  DECT > ('Not likely')  THEN 'Nodo 4' WHEN  DECT <= ('Not likely')  THEN 'Nodo 3'  END "
      
      
      # Example 2: a single unordered factor as predictor.
      TreeModel2 <- ctree(count ~ spray, data = InsectSprays)
      # Draw the fitted tree before extracting its rules.
      plot(TreeModel2, type = "simple")
      Result2 <- CtreePathFuncAllCat(TreeModel2)
      Result2
      ##$PreTable
      ##Node                                  Path     Means Counts            TruePath
      ##2    5 spray = { C, D, E }; spray = { C, E }  2.791667     24    spray = { C, E }
      ##3    4    spray = { C, D, E }; spray = { D }  4.916667     12       spray = { D }
      ##1    2                   spray = { A, B, F } 15.500000     36 spray = { A, B, F }
      ##
      ##$Table
      ##Node                Path     Means Counts
      ##2    5    spray = { C, E }  2.791667     24
      ##3    4       spray = { D }  4.916667     12
      ##1    2 spray = { A, B, F } 15.500000     36
      ##
      ##$SQL
      ##[1] " CASE  WHEN  spray = ('C','E')  THEN 'Nodo 5' WHEN  spray = ('D')  THEN 'Nodo 4' WHEN  spray = ('A','B','F')  THEN 'Nodo 2'  END "
      
      # Example 3: continuous predictors.
      # Keep only the rows of `airquality` with an observed Ozone value.
      airq <- airquality[!is.na(airquality$Ozone), ]
      TreeModel3 <- ctree(
        Ozone ~ .,
        data = airq,
        controls = ctree_control(maxsurrogate = 3)
      )
      Result2 <- CtreePathFuncAllCat(TreeModel3)
      Result2
      ##$PreTable
      ##  Node                                           Path    Means Counts
      ##1    5 Temp <= { 82 }; Wind > { 6.9 }; Temp <= { 77 } 18.47917     48
      ##3    6  Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
      ##4    9                 Temp > { 82 }; Wind > { 10.3 } 48.71429      7
      ##2    3                Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
      ##5    8                Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
      ##                                     TruePath
      ##1                Temp <= { 77 }; Wind > { 6.9 }
      ##3 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 }
      ##4                Temp > { 82 }; Wind > { 10.3 }
      ##2               Temp <= { 82 }; Wind <= { 6.9 }
      ##5               Temp > { 82 }; Wind <= { 10.3 }
      ##
      ##$Table
      ##  Node                                          Path    Means Counts
      ##1    5                Temp <= { 77 }; Wind > { 6.9 } 18.47917     48
      ##3    6 Temp <= { 82 }; Wind > { 6.9 }; Temp > { 77 } 31.14286     21
      ##4    9                Temp > { 82 }; Wind > { 10.3 } 48.71429      7
      ##2    3               Temp <= { 82 }; Wind <= { 6.9 } 55.60000     10
      ##5    8               Temp > { 82 }; Wind <= { 10.3 } 81.63333     30
      ##
      ##$SQL
      ##[1] " CASE  WHEN  Temp <= (77) AND  Wind > (6.9)  THEN 'Nodo 5' WHEN  Temp <= (82) AND  Wind > (6.9) AND  Temp > (77)  THEN 'Nodo 6' WHEN  Temp > (82) AND  Wind > (10.3)  THEN 'Nodo 9' WHEN  Temp <= (82) AND  Wind <= (6.9)  THEN 'Nodo 3' WHEN  Temp > (82) AND  Wind <= (10.3)  THEN 'Nodo 8'  END "
      

      更新!现在该函数支持分类变量和数值变量的混合!

      【讨论】:

      • 很好用,但是,它似乎只适用于分类变量:当我在 airct 树 CtreePathFuncAllCat(ct) 的结果上尝试这个时,它返回拆分字段,但不返回拆分标准。知道如何获取分类变量和连续变量的路径吗?
      • @clevelandfrowns 我更新了函数,现在可以处理连续和分类数据。
      【解决方案4】:

      这个函数应该可以完成这项工作

       CtreePathFunc <- function (ct, data) {
        # List, for every terminal node of a fitted tree, the sequence of
        # split conditions leading to it.
        #
        # ct:   tree object accepted by where() and nodes().
        # data: data frame whose columns include the split variables; it is
        #       indexed with the positions of each node's non-zero weights
        #       (presumably the data the tree was fitted on -- confirm).
        #
        # Returns a data frame with columns Node and Path, one row per
        # terminal node; Path holds conditions such as "Temp <= 82" joined
        # by ", " and is "" when no split lies above the node.

        ResulTable <- data.frame(Node = character(), Path = character())
        TerminalNodes <- unique(where(ct))

        for (Node in TerminalNodes) {
          # Candidate ancestors: ids below the terminal node that are not
          # terminal themselves. seq_len() gives an empty range for Node == 1,
          # where 1:(Node - 1) would give c(1, 0).
          NonTerminalNodes <- setdiff(seq_len(Node - 1), TerminalNodes[TerminalNodes < Node])

          # Getting the weights for that node
          NodeWeights <- nodes(ct, Node)[[1]]$weights

          # Finding the path: candidates sharing at least one weighted
          # observation with the terminal node.
          Path <- NULL
          for (i in NonTerminalNodes) {
            if (any(NodeWeights & nodes(ct, i)[[1]][2][[1]] == 1)) Path <- append(Path, i)
          }

          # Finding the splitting criteria for that path
          Path2 <- NULL

          # seq_along() skips the loop for an empty path instead of iterating
          # over c(1, 0).
          for (i in seq_along(Path)) {
            # Node reached after split i: next node on the path, or the
            # terminal node itself after the last split.
            n <- if (i == length(Path)) {
              nodes(ct, Node)[[1]]
            } else {
              nodes(ct, Path[i + 1])[[1]]
            }

            # Flattened split description, looked up once: its third entry is
            # used as the split point and its last entry as the variable name.
            Split <- unlist(nodes(ct, Path[i])[[1]][[5]])
            VarName <- as.character(Split[length(Split)])
            SplitPoint <- Split[3]

            # Rows with a missing split variable are ignored (na.rm) so they
            # cannot turn the test into NA, which `if` would reject.
            RowsInNode <- which(as.logical(n$weights))
            AllAtOrBelow <- all(data[RowsInNode, VarName] <= as.numeric(SplitPoint),
                                na.rm = TRUE)
            SB <- if (AllAtOrBelow) "<=" else ">"

            Path2 <- paste(c(Path2, paste(VarName, SB, as.character(SplitPoint))),
                           collapse = ", ")
          }

          # A node with no split above it gets an empty path rather than
          # dropping the Path column from the row.
          if (is.null(Path2)) Path2 <- ""

          # Output
          ResulTable <- rbind(ResulTable, cbind(Node = Node, Path = Path2))
        }
        return(ResulTable)
      }
      

      测试

      library(party)
      # Keep only the rows of `airquality` with an observed Ozone value.
      airq <- airquality[!is.na(airquality$Ozone), ]
      ct <- ctree(
        Ozone ~ .,
        data = airq,
        controls = ctree_control(maxsurrogate = 3)
      )
      # The data frame must be supplied as well: it is used to work out the
      # direction of each split.
      Result <- CtreePathFunc(ct, airq)
      Result
      
      ##   Node                               Path
      ## 1    5 Temp <= 82, Wind > 6.9, Temp <= 77
      ## 2    3            Temp <= 82, Wind <= 6.9
      ## 3    6  Temp <= 82, Wind > 6.9, Temp > 77
      ## 4    9             Temp > 82, Wind > 10.3
      ## 5    8            Temp > 82, Wind <= 10.3
      

      【讨论】:

      • 运行耗时较长,但这个回答非常好。另外,您忘了把“airq”数据作为参数传入。
      • 谢谢,@Galled。已编辑。我也忘了library(party)。这是我在 SO 中的第一个答案之一,所以那里有点菜鸟
      • 这个函数是否有任何更新版本也可以处理分类解释变量? @DavidArenburg
      • @JoãoDaniel,我没有写过。也许可以发布一个新问题,看看是否有人能详细说明,因为我不确定自己在不久的将来是否有时间写一个
      • @JoãoDaniel 我做了一个。
      猜你喜欢
      • 2015-07-12
      • 2016-02-09
      • 2021-07-17
      • 2015-08-19
      • 2015-04-22
      • 2020-12-10
      • 2012-12-07
      • 2015-09-09
      • 2012-07-12
      相关资源
      最近更新 更多