【问题标题】:Fitting different models to each subset of data in R将不同的模型拟合到 R 中的每个数据子集
【发布时间】:2017-01-07 09:10:12
【问题描述】:

我有一个包含多个类的大型数据集。我的目标是为每个类拟合一个模型,然后预测结果并在一个方面为每个类可视化它们。

对于一个可重现的示例,我使用mtcars 创建了一些基本的东西。这适用于每个类的简单回归模型。

mtcars = data.table(mtcars)
model = mtcars[, list(fit = list(lm(mpg~disp+hp+wt))), keyby = cyl]
setkey(mtcars, cyl)
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI]
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl)

但是,我想尝试以下类似的方法,但它还不起作用。此尝试使用公式列表,但我也希望将不同的模型(一些 glms,一些树)发送到每个数据子集。

mtcars = data.table(mtcars)
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+"))))
model = mtcars[, list(fit = list(lm(form))), keyby = cyl]
setkey(mtcars, cyl)
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI]
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl)

【问题讨论】:

  • 数据表真的有必要吗?
  • 否,但在大型数据集上速度更快,因此首选。 dplyr 也可以。
  • 我只是在暗示瓶颈将被预测,lm,ggplot。是list(fit = lapply(form, lm, data = .SD))你想要的
  • 是的,这是我犯的错误。

标签: r data.table


【解决方案1】:

lm() 也接受公式作为字符向量。因此,我只需将form 创建为:

form = lapply(factors, function(x) paste("mpg~", paste(x, collapse="+")))

并且,您需要提供正确的数据(使用内置特殊符号 .SD 对应于每个组):

model = mtcars[, list(fit=lapply(form, lm, data=.SD)), keyby=cyl]

对于每个cyl,循环遍历form,并且每次将相应的公式作为第一个参数传递给lmdata = .SD,其中.SD代表数据子集 本身就是一个data.table。您可以从vignettes 了解更多信息。


如果您还想在结果中包含公式,那么:

chform = unlist(form)
model = mtcars[, list(form=chform, fit=lapply(form, lm, data=.SD)), keyby = cyl]

HTH

PS:如果您打算使用 data.tables 在[...] 中使用update(),请阅读this post

【讨论】:

  • 这解决了我目前面临的问题。唯一的问题 - 我不明白为什么在不提供data=.SD 的情况下拟合一个常见模型时,它可以工作?
  • 公式对象还捕获创建它们的环境......这就是将要使用的内容。看看?lm
【解决方案2】:

这是一种方法,我们为每个模型设置predict 作为未评估列表,在data.table 对象中评估它们,gather 输出,并将其传递给ggplot

models = quote(list(
      predict(lm(form[[1]], .SD)),
      predict(lm(form[[2]], .SD)), 
      predict(lm(form[[3]], .SD))))

d <- mtcars
d[, c("est1", "est2", "est3") := eval(models), by = cyl]
d <- tidyr::gather(d, key = model, value = pred, est1:est3)

library(ggplot2)
ggplot(d, aes(x = mpg, y = pred)) + geom_line() + facet_grid(cyl ~ model)

输出:

【讨论】:

    【解决方案3】:

    我现在实际上正在做这件事,非常完美的时机。这将是一个“tidyverse”重的答案,但我真的很喜欢它的工作方式。

    purrr 有一些非常方便的map 函数,当与tibble 中的列表列结合使用时,这些函数会非常流畅。使用你的定义(我不想优化它)

    library(data.table)
    mtcars = data.table(mtcars)
    factors = list(c("disp","wt"), c("disp"), c("hp"))
    form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+"))))
    

    提供函数列表,这些可以传递给purrr::invoke_map,它将参数列表(您拥有)应用于函数列表(在您的情况下,只是lm,但我怀疑可以扩展到其他人也)带有可选参数(在您的示例中为mtcars)。使用tibble,这些存储为一个整洁的data.frame-esque list,否则返回为lm对象

    library(tibble)
    library(purrr) 
    models <- tibble(fit = invoke_map(lm, form, data = mtcars))
    models
    #> # A tibble: 3 x 1
    #>          fit
    #>       <list>
    #>   1 <S3: lm>
    #>   2 <S3: lm>
    #>   3 <S3: lm>
    

    当你想对所有这些元素做一些事情时,超级有用的部分就出现了,比如提取拟合系数:

    map(models$fit, coefficients)
    #> [[1]]
    #> (Intercept)        disp          wt 
    #> 34.96055404 -0.01772474 -3.35082533 
    #> 
    #> [[2]]
    #> (Intercept)        disp 
    #> 29.59985476 -0.04121512 
    #> 
    #> [[3]]
    #> (Intercept)          hp 
    #> 30.09886054 -0.06822828 
    

    或重新检查使用的公式

    map(models$fit, formula)
    #> [[1]]
    #> mpg ~ disp + wt
    #> <environment: 0x0000000017ee73a8>
    #>   
    #>   [[2]]
    #> mpg ~ disp
    #> <environment: 0x0000000018392c58>
    #>   
    #>   [[3]]
    #> mpg ~ hp
    #> <environment: 0x0000000018471d18>
    

    此外,如果你想从模型中添加一些预测,这很容易使用broom::augment实现

    library(broom)
    models_with_predicts <- models %>% mutate(predict = map(fit, augment))
    models_with_predicts
    #> # A tibble: 3 x 2
    #>          fit                predict
    #>       <list>                 <list>
    #>   1 <S3: lm> <data.frame [32 x 10]>
    #>   2 <S3: lm>  <data.frame [32 x 9]>
    #>   3 <S3: lm>  <data.frame [32 x 9]>
    

    您可以通过unnest()ing 返回数据级别(带有预测),但这会合并您的所有数据(添加分组级别以保持拟合分开)

    library(tidyr)
    unnest(models_with_predicts, predict)
    
    #> # A tibble: 96 x 11
    #> mpg  disp    wt  .fitted   .se.fit     .resid       .hat   .sigma     .cooksd .std.resid    hp
    #> <dbl> <dbl> <dbl>    <dbl>     <dbl>      <dbl>      <dbl>    <dbl>       <dbl>      <dbl> <dbl>
    #>   1   21.0 160.0 2.620 23.34543 0.6075520 -2.3454326 0.04339369 2.933379 0.010222201 -0.8222164    NA
    #> 2   21.0 160.0 2.875 22.49097 0.6221836 -1.4909721 0.04550894 2.954135 0.004351414 -0.5232550    NA
    #> 3   22.8 108.0 2.320 25.27237 0.7326015 -2.4723669 0.06309504 2.928665 0.017217431 -0.8757799    NA
    #> 4   21.4 258.0 3.215 19.61467 0.5743205  1.7853334 0.03877647 2.948162 0.005241995  0.6243627    NA
    #> 5   18.7 360.0 3.440 17.05281 1.0943208  1.6471930 0.14078260 2.949120 0.020275438  0.6092882    NA
    #> 6   18.1 225.0 3.460 19.37863 0.6122393 -1.2786309 0.04406584 2.957872 0.003089406 -0.4483953    NA
    #> 7   14.3 360.0 3.570 16.61720 0.9897465 -2.3171997 0.11516157 2.931444 0.030948880 -0.8446199    NA
    #> 8   24.4 146.7 3.190 21.67120 0.9053245  2.7287988 0.09635365 2.918183 0.034431234  0.9842424    NA
    #> 9   22.8 140.8 3.150 21.90981 0.9165259  0.8901898 0.09875274 2.962885 0.003775416  0.3215070    NA
    #> 10  19.2 167.6 3.440 20.46305 0.9678618 -1.2630477 0.11012510 2.957375 0.008693734 -0.4590766    NA
    #> # ... with 86 more rows
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2013-10-07
      • 2019-10-07
      • 1970-01-01
      相关资源
      最近更新 更多