【问题标题】:How R2 and RMSE are calculated in cross-validation of pls Rpls R 的交叉验证中如何计算 R2 和 RMSE
【发布时间】:2018-10-02 22:18:26
【问题描述】:

我正在使用 Mevik (2007) 的 pls R 包进行偏最小二乘回归。 10折交叉验证模型如下:

pls.fa <- plsr(FA ~ ., ncomp = xcomp,scale = TRUE, validation = "CV", segments = 10,jackknife =TRUE, data=train)

然后,我可以打印出精度,例如 R2 或 RMSE 使用:

R2(pls.fa,ncomp=1:xcomp)

其中 xcomp 是组件的最佳数量。 例如,R2 的结果如下所示:

Intercept)      1 comps      2 comps      3 comps      4 comps      5 comps      6 comps      7 comps      8 comps      9 comps  
  -0.009828     0.551053     0.570584     0.574790     0.580414     0.583354     0.585812     0.580690     0.581536     0.595441  
   10 comps  
   0.596096  

我的问题是:这个交叉验证产生的 R2 是多少,是 10 倍的平均值吗?

谢谢

【问题讨论】:

    标签: r cross-validation pls


    【解决方案1】:

    我进行了一些测试,似乎 pls::R2pls::RMSEP 返回的 R2RMSE 不是 10 折的平均统计数据。它们是通过从所有 10 个 CV 折叠中提取预测并将它们与观察到的结果进行比较,同时使用所有保留样本来计算的:

    这是一个例子:

    library(pls)
    

    用内置的纱线数据集拟合模型:

    data(yarn)
    pls.fa <- plsr(density ~ NIR,
                   data = yarn,
                   ncomp = 6,
                   scale = TRUE,
                   validation = "CV",
                   segments = 10,
                   jackknife = TRUE)
    

    我将使用等效的caret 函数进行比较

    以下代码返回使用前 1:6 分量获得的 RMSE:

    pls::RMSEP(pls.fa, ncomp = 1:6, estimate = "CV", intercept = FALSE) 
    #output
    1 comps  2 comps  3 comps  4 comps  5 comps  6 comps  
     8.4692   2.5553   1.9430   1.0151   0.7399   0.5801  
    

    以数值向量的形式提取 RMSE:

    unlist(lapply(1:6, function(x) pls::RMSEP(pls.fa,
                                              ncomp = 1:6,
                                              estimate = "CV",
                                              intercept = FALSE)$val[,,x]))
    

    让我们使用所有数据将输出与caret::RMSE 进行比较:

    all.equal(
      unlist(lapply(1:6, function(x) caret::RMSE(pls.fa$validation$pred[,,x],
                                                 yarn$density))),
      unlist(lapply(1:6, function(x) pls::RMSEP(pls.fa,
                                                ncomp = 1:6,
                                                estimate = "CV",
                                                intercept = FALSE)$val[,,x])))
    #output  
    TRUE
    

    所以RMSEP 是通过使用所有保留预测来计算的。

    相当于R2:

    all.equal(
      unlist(lapply(1:6, function(x) caret::R2(pls.fa$validation$pred[,,x],
                                               yarn$density,
                                               form = "traditional"))),
      unlist(lapply(1:6, function(x) pls::R2(pls.fa,
                                             ncomp = 1:6,
                                             estimate = "CV",
                                             intercept = FALSE)$val[,,x])))
    #output  
    TRUE
    

    编辑:回答评论中的问题:

    哪种方法更好地在折叠上平均 RMSE,或者从折叠中提取所有预测并计算一个 RMSE:

    在我看来,任何一种方式都很好,只是在比较模型时需要在计算中保持一致。考虑以下示例:

    set.seed(1)
    true <- rnorm(100)
    fold <- sample(1:10, size = 100, replace = T)
    pred <- rnorm(100)
    
    z <- data.frame(true, pred, fold)
    
    library(tidyverse)
    
    z %>%
      group_by(fold) %>%
      summarise(rmse = caret::RMSE(true, pred)) %>%
      pull(rmse) %>%
      mean
    #ouput
     1.479923
        
    z %>%
      summarise(rmse = caret::RMSE(true, pred)) %>%
      pull(rmse) 
    #ouput
    1.441471
    

    与提取所有预测并计算 RMSE 相比,此处对折叠进行平均给出了更悲观的结果。

    使用与 set.seed(2) 相同的代码:

    平均折叠次数:1.442483 拉所有:1.500432

    现在平均折叠次数更加乐观

    因此,一种方法并不总是更乐观。

    【讨论】:

    • 这对我来说是非常有用的信息,非常感谢@missuse。这是否意味着每次模型进行交叉验证时,它都会保存预测值,并且在所有 10 次之后将它们拉到一起并将它们与观察值相关联?
    • @Phuong Ho 很高兴为您提供帮助。是的,这是正确的,所有来自 CV 的预测都被提取并与观察值进行比较。编辑答案更容易理解。但是,由于输出 pls 对象包含折叠中的观察索引,您可以使用自定义函数计算您想要的任何指标。
    • @missue:您是否认为像上面的方法一样计算 R2(即所有保留的 cv 预测及其实际值比计算每个 CV 折叠的 R 并平均它们更好。我找到了所有保留的产生比所有 10 倍的平均 R2 更高的 R2。
    • @hn.phuong 我通常计算每个折叠的 R2 并对它们进行平均,因为这种方式提供了一些方差估计。但我认为这并不重要,更重要的是在比较模型时保持一致。无论哪种方式都可以正常工作。
    • @missuse:我问这个问题是因为我正在使用 h2o.deeplearning 做相同的模型,我平均了每个折叠的 R2,当然我们也可以查看方差。但是,正如我在之前的评论中提到的,我发现这个均值 R2 比保持 R2 小,这可能会导致如此乐观的结果
    猜你喜欢
    • 2018-11-18
    • 1970-01-01
    • 2014-07-16
    • 2021-11-24
    • 2023-04-03
    • 1970-01-01
    • 2019-04-24
    • 2017-03-31
    • 2017-12-18
    相关资源
    最近更新 更多