【问题标题】:referencing rpart terminal nodes in r在 r 中引用 rpart 终端节点
【发布时间】:2016-06-10 05:11:48
【问题描述】:

我是 R(和 rpart)的新手。我有车辆模型数据(约 400 个模型)。我正在使用 rpart 将这些分组为具有相似车辆维修成本的较小数量(例如 5-10 组)。我已成功运行 rpart 并拥有这些分组。

fit <- rpart(repairs ~ model, data=data, method='anova', control=rpart.control(minsplit=2,minbucket=1,cp=.0005))    

假设每个终端节点中大约有 40-80 个模型。有没有办法让我创建一个引用终端节点中的值的公式。假设 data$model 包含所有模型名称(并且是我正在尝试执行的自变量:

data$modelgroup <- data$model
data$modelgroup[data$modelgroup %in% terminal node 1] <- 'Group1'
data$modelgroup[data$modelgroup %in% terminal node 2] <- 'Group2'
and so on for the rest of the groups

另外,如果有一种方法可以做到这一点,而不必为每个组编写一行代码,那就太好了。

我知道我可以打印树并从终端节点手动复制文本并以这种方式完成,但这非常低效。

提前感谢您的帮助!

根据下面的要求,我在下面添加了一个可重现的示例。

data <- read.csv("rpart_example.csv")
data

data[,1:2]

   Model Amount
1      a      1
2      a      1
3      a      1
4      b      1
5      b      1
6      b      1
7      c      2
8      c      2
9      c      2
10     d      2
11     d      2
12     d      2
13     e      3
14     e      3
15     e      3
16     f      4
17     f      4
18     f      4

fit <- rpart(Amount ~ Model, data=data, method='anova', 
          control=rpart.control(minsplit=2,minbucket=1,cp=.0005))
print(fit)

n= 18 

node), split, n, deviance, yval
* denotes terminal node

1) root 18 20.5 2.166667  
2) Model=a,b,c,d 12  3.0 1.500000  
4) Model=a,b 6  0.0 1.000000 *
  5) Model=c,d 6  0.0 2.000000 *
  3) Model=e,f 6  1.5 3.500000  
6) Model=e 3  0.0 3.000000 *
  7) Model=f 3  0.0 4.000000 *

# create a variable modelgroup that groups models per terminal nodes from rpart     

# I can do this manually as below
# is there a way for me to automate this assignment?

data$modelgroup <- as.character(data$Model)

# per rpart output, a&b are grouped into one terminal node
data$modelgroup[data$modelgroup %in% c('a','b')] <- 'Group1'    

# per rpart output, c&d are grouped into the second terminal node
data$modelgroup[data$modelgroup %in% c('c','d')] <- 'Group2'

# per rpart, e is the third terminal node
data$modelgroup[data$modelgroup == 'e'] <- 'Group3'

# per rpart, f is the fourth terminal node
data$modelgroup[data$modelgroup == 'f'] <- 'Group4'

【问题讨论】:

  • 如果您提供一个最小的 [可重现的示例] 会更容易为您提供帮助。包括一些示例数据并为该输入指定所需的输出。
  • 我不确定我是否可以提供一个可重现的示例,但作为说明,例如在运行 rpart 和 print(fit) 之后,其中一个终端节点包含 FordTaurus、ChevyMalibu ... 和 40+更多型号名称。假设我想将此终端节点中列出的所有模型称为“组 1”。我基本上想要代码说明模型名称是否在此列表中,将其称为“Group 1”并为每个终端节点执行此操作。
  • 提供假设的解决方案并不容易。或许您可以修改 rpart 帮助页面中的示例以制作可重现的示例。
  • 抱歉耽搁了。直到现在我都无法回到这个话题。根据您的要求,我添加了一个可重现的示例。我将它添加到我原来的帖子的底部。我希望这会有所帮助。

标签: r rpart


【解决方案1】:

rpart 对象中,您要查找的信息基本上很容易存储在$where 元素中。它为您提供了每个观察被分配到的节点号:

table(fit$where, data$modelgroup)
##     Group1 Group2 Group3 Group4
##   3      6      0      0      0
##   4      0      6      0      0
##   6      0      0      3      0
##   7      0      0      0      3

当然,您也可以将节点 ID(3、4、6、7)切换为因子或字符变量,例如 factor(fit$where, levels = c(3, 4, 6, 7), labels = paste0("Group", 1:4)) 或类似的东西。

如果您想通过简单统一的界面对新数据执行此操作,可以将您的 rpart 对象转换为 party 包中的 party 对象@:

library("partykit")
fit2 <- as.party(fit)

print(fit2)plot(fit2)的统一方法可用,predict(fit2, ...)有不同的类型:

table(predict(fit2, newdata = data, type = "node"), data$modelgroup)
##     Group1 Group2 Group3 Group4
##   3      6      0      0      0
##   4      0      6      0      0
##   6      0      0      3      0
##   7      0      0      0      3

这将返回与上述相同的结果,但也可以轻松应用于其他 newdata

【讨论】:

  • 感谢您的回复。我很感激!我尝试使用建议的代码 (table(fit$where, data$modelgroup)),但收到一条错误消息。我也对这样做的目的感到困惑。在我上面的示例中,如何使用它为模型 a 和 b 分配“组 1”给新变量 data$modelgroup(因为它们在第一个终端节点中)?
  • 该表应该显示来自fit$where的分组和手动构造的modelgroup重合,即提供相同的组信息。如果没有重现错误消息的独立示例,我无法评论错误。我使用的示例是您发布的简单 18 观察数据集。
  • 感谢您的澄清。我让它工作。感谢您的帮助!
  • 我调整了您的回复并使用了以下代码。再次感谢! data$modelgroup
猜你喜欢
  • 2017-08-09
  • 2015-05-31
  • 2017-06-03
  • 2021-11-20
  • 2012-04-05
  • 1970-01-01
  • 2016-02-09
  • 2015-10-28
  • 1970-01-01
相关资源
最近更新 更多