【问题标题】:R randomForest: prediction values for non-terminals?R randomForest:非终端的预测值?
【发布时间】:2015-07-24 23:38:24
【问题描述】:

R randomForest 的文档与 getTree() 方法的输出之间存在差异。

documentation 声明 getTree() 中预测字段的值对于非终端节点应为零:

prediction:对节点的预测;如果节点不是终端则为 0

这是分类方法的情况,但是在回归方法中有非零连续值:

> library(randomForest)
> 
> x  <- data.frame(matrix(rnorm(20), nrow=10))
> y  <- rnorm(10)
> 
> model <- randomForest(x,y)
> getTree(model,k=1)
  left daughter right daughter split var split point status prediction
1             2              3         2  0.49239435     -3 -0.1212934
2             4              5         2  0.09046437     -3 -0.4871480
3             0              0         0  0.00000000     -1  1.3421250
4             6              7         2 -0.61841853     -3 -0.2501163
5             0              0         0  0.00000000     -1 -1.1982434
6             0              0         0  0.00000000     -1 -0.8738258
7             0              0         0  0.00000000     -1  0.9973027

这些非终端节点的预测值是否用于预测?如果不是,他们的目的是什么?

在单独但相关的注释中,status 字段也不符合回归方法的文档:

status:status是节点终端(-1)还是不是(1)

如前所述,分类方法似乎完全遵循文档:

> y_bin <- as.factor(y>0)
> model <- randomForest(x,y_bin)
> getTree(model,k=1)
  left daughter right daughter split var split point status prediction
1             2              3         2  -0.6184185      1          0
2             0              0         0   0.0000000     -1          1
3             4              5         1  -0.3887568      1          0
4             0              0         0   0.0000000     -1          1
5             0              0         0   0.0000000     -1          2

【问题讨论】:

    标签: r random-forest


    【解决方案1】:

    我已经用一个测试用例确认了 randomForest 回归器的预测不依赖于预测字段中的决策节点值:

    > x     <- data.frame(matrix(rnorm(20), nrow=10))
    > y     <- rnorm(10)
    > 
    > model <- randomForest(x,y,ntree=1)
    > getTree(model,k=1)
      left daughter right daughter split var split point status prediction
    1             2              3         2  -0.1314179     -3 -0.1901029
    2             0              0         0   0.0000000     -1 -1.6884260
    3             4              5         2   1.0801034     -3  0.1844779
    4             0              0         0   0.0000000     -1 -0.0447021
    5             0              0         0   0.0000000     -1  0.4136579
    > 
    > test <- data.frame(X1=1,X2=1)
    > predict(model,test)
             1 
    -0.0447021 
    

    如果来自数据点的值大于隔断。

    【讨论】:

      猜你喜欢
      • 2018-05-11
      • 2016-04-24
      • 2019-04-27
      • 2017-01-11
      • 2021-03-16
      • 2021-10-28
      • 2014-11-29
      • 2016-11-23
      • 1970-01-01
      相关资源
      最近更新 更多