【问题标题】:R Faster way to create lag by entire group using data.tableR使用data.table更快地按整个组创建滞后的方法
【发布时间】:2021-01-31 17:58:13
【问题描述】:

我有一个简单的data.table如下-

  ID = c(rep("A", 1000), rep("B", 1000), rep("C", 1000), rep("D", 1000))
  val = c("a", "a", "a", "b", "b", "c", "c","d","d","d","d","e","e","f","f","g","g","g","g","g")

  dt = data.table(ID, val)

我想在此 data.table 中添加一个新列,该列将按组 ID 延迟 val

这是预期的输出

> head(dt, 20)
     ID val val_lag
 1:  A   a    <NA>
 2:  A   a    <NA>
 3:  A   a    <NA>
 4:  A   b       a
 5:  A   b       a
 6:  A   c       b
 7:  A   c       b
 8:  A   d       c
 9:  A   d       c
10:  A   d       c
11:  A   d       c
12:  A   e       d
13:  A   e       d
14:  A   f       e
15:  A   f       e
16:  A   g       f
17:  A   g       f
18:  A   g       f
19:  A   g       f
20:  A   g       f

我目前使用的解决方案是 -

dt[, val_lag := with(rle(val), rep(c(NA, head(values, -1)), lengths)), by = ID]

但是,此解决方案在实际数据集上非常慢,该数据集非常大并且有数百万行。有没有更快的方法来解决这个问题?

以下是本文讨论的所有方法的性能结果 -

  microbenchmark::microbenchmark(rles = dt[, val_lag1 := with(rle(val), rep(c(NA, head(values, -1)), lengths)), by = ID],
                                 chinsoon = dt[, val_lag := shift(val)[nafill(replace(seq.int(.N), rowid(rleid(val)) > 1L, NA_integer_), "locf")], by = ID],
                                 TiC = dt[, val_lag3 := c(NA,rle(val)$values)[cumsum(c(0,head(val,-1)!=tail(val,-1)))+1], by = ID],
                                 times = 1000
  )

Unit: milliseconds
     expr      min       lq     mean   median       uq      max neval cld
     rles 1.549548 1.781014 2.750187 2.096805 2.743668 46.65326  1000  a 
 chinsoon 1.766827 2.060233 3.059109 2.379477 3.077080 67.16040  1000  a 
      TiC 1.986808 2.226933 3.472451 2.624236 3.397165 60.67802  1000   b

谢谢!

【问题讨论】:

    标签: r data.table


    【解决方案1】:

    这是另一种选择:

    dt[, val_lag := shift(val)[nafill(replace(seq.int(.N), rowid(rleid(val)) > 1L, NA_integer_), "locf")]]
    

    计时码:

    library(data.table)
    set.seed(0L)
    nr <- 1e6
    ng <- 1e5
    dt = data.table(ID=sample(ng, nr, TRUE), val=as.character(sample(nr, nr, TRUE)))
    setorder(dt, ID, val)
    
    microbenchmark::microbenchmark(times = 3L,
        opt = dt[, val_lag := shift(val)[nafill(replace(seq.int(.N), rowid(rleid(val)) > 1L, NA_integer_), "locf")]],
        rle = dt[, val_lag := with(rle(val), rep(c(NA, head(values, -1)), lengths)), by = ID]
    )
        
    

    时间安排:

    Unit: milliseconds
     expr       min        lq      mean    median        uq       max neval
      opt  133.8857  159.8922  265.2029  185.8987  330.8614  475.8242     3
      rle 3097.6005 3123.5422 3193.2654 3149.4839 3241.0978 3332.7117     3
    

    编辑:添加了正在发生的事情的示例:

    index         |    1    2    3    4    5    6    7    8    9   10
    value         |    a    a    a    b    b    c    c    c    d    d
    
    shifted (s)   |   NA    a    a    a    b    b    c    c    c    d
    rowid+rleid   |    1    2    3    1    2    1    2    3    1    2
    replace       |    1   NA   NA    4   NA    6   NA   NA    9   NA <In ?nafill, Only double and integer data types are currently supported. Hence, nafill the indices before accessing>
    nafill        |    1    1    1    4    4    6    6    6    9    9
    using s above | s[1] s[1] s[1] s[4] s[4] s[6] s[6] s[6] s[9] s[9]
    

    【讨论】:

    • 这真是绝妙的解决方案! +1!
    • 我刚刚注意到您没有申请by = ID,这可能会影响所需的输出和速度。
    • 你是对的 Thomas,如果解决方案不使用 by = ID,那么 val 将越过 ID 组,值将从一组溢出到另一组。添加by=ID 使这个解决方案比rle 解决方案慢一点。
    • 我猜OP想要by = ID的结果,否则会给出不同的输出
    • shift 会将最后一个值滞后到下一组
    【解决方案2】:

    我猜你可以试试下面的代码

    dt[,val_tag := c(NA,rle(val)$values)[cumsum(c(0,head(val,-1)!=tail(val,-1)))+1],ID]
    

    你会看到

          ID val val_tag
       1:  A   a    <NA>
       2:  A   a    <NA>
       3:  A   a    <NA>
       4:  A   b       a
       5:  A   b       a
      ---
    3996:  D   g       f
    3997:  D   g       f
    3998:  D   g       f
    3999:  D   g       f
    4000:  D   g       f
    

    > tail(dt,30)
        ID val val_tag
     1:  D   d       c
     2:  D   e       d
     3:  D   e       d
     4:  D   f       e
     5:  D   f       e
     6:  D   g       f
     7:  D   g       f
     8:  D   g       f
     9:  D   g       f
    10:  D   g       f
    11:  D   a       g
    12:  D   a       g
    13:  D   a       g
    14:  D   b       a
    15:  D   b       a
    16:  D   c       b
    17:  D   c       b
    18:  D   d       c
    19:  D   d       c
    20:  D   d       c
    21:  D   d       c
    22:  D   e       d
    23:  D   e       d
    24:  D   f       e
    25:  D   f       e
    26:  D   g       f
    27:  D   g       f
    28:  D   g       f
    29:  D   g       f
    30:  D   g       f
        ID val val_tag
    

    【讨论】:

    • 谢谢!添加grpid 使代码更容易理解。
    • @Saurabh 不客气!如果你更关心速度,你可以同时测试它们(我猜第一种方法会更快)
    • 我已经尝试了这两种解决方案。第一种方法给出了部分正确的结果。在运行 tail(dt, 30) 时,有几个 NA 值,而之前的值存在。第二种方法将所有值填充为 NA。
    • @Saurabh 再次查看我的更新,但我不知道我的解决方案很快
    • @TiC - 我在问题中添加了基准测试。我认为在您的解决方案中删除函数 headtail 会使其更快。可以使用.N 代替tail
    【解决方案3】:

    我不确定在您的示例中是否需要按 ID 分组。您基本上可以查找一个移位的命名向量,这似乎更快:

    library(data.table)
    library(microbenchmark)
    ID = c(rep("A", 1000), rep("B", 1000), rep("C", 1000), rep("D", 1000))
    val = c("a", "a", "a", "b", "b", "c", "c","d","d","d","d","e","e","f","f","g","g","g","g","g")
    dt = data.table(ID, val)
    lt <- setNames(c(NaN, seq_along(unique(val))), c(NA_character_, unique(val)))
    
    microbenchmark(
        rle = dt[, val_lag := with(rle(val), rep(c(NA, head(values, -1)), lengths)), by = ID],
        TiC = dt[, val_lag := shift(unique(val))[as.integer(factor(paste(ID, val)))], ID], 
        me = dt[, val_lag := names(lt)[lt[val]]], 
        control = list(warmup=10)
    )
    #> Unit: microseconds
    #>  expr      min        lq      mean    median       uq      max neval cld
    #>   rle  614.544  653.2165  772.9975  775.7005  844.245 1391.390   100  b 
    #>   TiC 1249.129 1286.2355 1578.1695 1412.4135 1553.035 6148.756   100   c
    #>    me  330.570  346.1440  414.7982  386.9125  440.422  910.842   100 a
    
    identical(dt[, val_lag:=names(lt)[lt[val]]],  
              dt[, val_lag := with(rle(val), rep(c(NA, head(values, -1)), lengths)), by = ID])
    #> [1] TRUE
    

    reprex package (v1.0.0) 于 2021-01-31 创建

    【讨论】:

    • 在运行 tail(dt,30) 时,TiCme 解决方案中有几个 NA 值;只有rle 产生正确的结果。
    猜你喜欢
    • 1970-01-01
    • 2012-07-09
    • 2020-11-10
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-07-02
    相关资源
    最近更新 更多