【Question title】: R data.table: what is the fastest way to intersect a data.table by multiple columns by keys and groups
【Posted】: 2023-03-23 08:07:01
【Question】:

MAJOR EDIT to clarify that the answers given are wrong

I have a data.table with group columns (split_by), key columns (key_by) and trait ID columns (intersect_by).

Within each split_by group, I want to keep only the rows whose trait IDs are shared by all keys present in that group.

For example:

dt <- data.table(id = 1:6, key1 = 1, key2 = c(1:2, 2), group_id1= 1, group_id2= c(1:2, 2:1, 1:2), trait_id1 = 1, trait_id2 = 2:1)
setkey(dt, group_id1, group_id2, trait_id1, trait_id2)
dt
   id key1 key2 group_id1 group_id2 trait_id1 trait_id2
1:  4    1    1         1         1         1         1
2:  1    1    1         1         1         1         2
3:  5    1    2         1         1         1         2
4:  2    1    2         1         2         1         1
5:  6    1    2         1         2         1         1
6:  3    1    2         1         2         1         2

res <- intersect_this_by(dt,
                         key_by = c("key1", "key2"),
                         split_by = c("group_id1", "group_id2"),
                         intersect_by = c("trait_id1", "trait_id2"))

I would like res to be:

> res[]
   id key1 key2 group_id1 group_id2 trait_id1 trait_id2
1:  1    1    1         1         1         1         2
2:  5    1    2         1         1         1         2
3:  2    1    2         1         2         1         1
4:  6    1    2         1         2         1         1
5:  3    1    2         1         2         1         2

We can see that id 4 was removed: in the group defined by group_id1 = 1 and group_id2 = 1 (the group id 4 belongs to), only one key combination, (1,1), carries the trait (1,1), while the group contains two key combinations, (1,1) and (1,2). The trait (1,1) is therefore not shared by all keys of the group, so it is dropped from the group and id 4 goes with it. Conversely, ids 1 and 5 carry the same trait with different keys, and together those keys cover all the keys of the group ((1,1) and (1,2)), so the trait is kept and ids 1 and 5 stay.
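The rule can be sketched directly on the toy data. A minimal, self-contained sketch (assuming, as the prose above describes, that a row's key is the (key1, key2) combination):

```r
library(data.table)

dt <- data.table(id = 1:6, key1 = 1, key2 = c(1:2, 2), group_id1 = 1,
                 group_id2 = c(1:2, 2:1, 1:2), trait_id1 = 1, trait_id2 = 2:1)

# number of distinct key combinations present in each split_by group
dt[, n_keys := uniqueN(.SD), by = .(group_id1, group_id2),
   .SDcols = c("key1", "key2")]
# a trait survives only if it is carried by every key combination of its group
dt[, keep := n_keys == uniqueN(.SD),
   by = .(group_id1, group_id2, trait_id1, trait_id2),
   .SDcols = c("key1", "key2")]
res <- dt[keep == TRUE, !c("n_keys", "keep")]
res$id  # id 4 is dropped; ids 1, 2, 3, 5 and 6 remain
```

This is the same count comparison that the generic function below performs per group.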

Here is a function that achieves this:

intersect_this_by2 <- function(dt,
                               key_by = NULL,
                               split_by = NULL,
                               intersect_by = NULL){

    dtc <- as.data.table(dt)       

    # compute number of keys in the group
    dtc[, n_keys := uniqueN(.SD), by = split_by, .SDcols = key_by]
    # compute number of keys represented by each trait in each group 
    # and keep row only if they represent all keys from the group
    dtc[, keep := n_keys == uniqueN(.SD), by = c(intersect_by, split_by), .SDcols = key_by]
    dtc <- dtc[keep == TRUE][, c("n_keys", "keep") := NULL]
    return(dtc)      
}

But it becomes quite slow for large datasets or complicated traits/keys/groups... The real data.table has 10 million rows and the traits have 30 levels... Is there any way to improve this? Are there any obvious pitfalls? Thanks for your help!

FINAL EDIT: Uwe proposed a concise solution that is 40% faster than my initial code (which I removed here because it was confusing). The final function looks like this:

intersect_this_by_uwe <- function(dt,
                                  key_by = c("key1"),
                                  split_by = c("group_id1", "group_id2"),
                                  intersect_by = c("trait_id1", "trait_id2")){
    dti <- copy(dt)
    dti[, original_order_id__ := 1:.N]
    setkeyv(dti, c(split_by, intersect_by, key_by))
    uni <- unique(dti, by = c(split_by, intersect_by, key_by))
    unique_keys_by_group <-
        unique(uni, by = c(split_by, key_by))[, .N, by = c(split_by)]
    unique_keys_by_group_and_trait <-
        uni[, .N, by = c(split_by, intersect_by)]
    # 1st join to pick group/traits combinations with equal number of unique keys
    selected_groups_and_traits <-
        unique_keys_by_group_and_trait[unique_keys_by_group,
                                       on = c(split_by, "N"), nomatch = 0L]
    # 2nd join to pick records of valid subsets
    dti[selected_groups_and_traits, on = c(split_by, intersect_by)][
        order(original_order_id__), -c("original_order_id__","N")]
}
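For illustration, here is a condensed, self-contained version of the same two-join idea applied to the toy example from the question (again treating the (key1, key2) pair as the key, per the prose description):

```r
library(data.table)

dt <- data.table(id = 1:6, key1 = 1, key2 = c(1:2, 2), group_id1 = 1,
                 group_id2 = c(1:2, 2:1, 1:2), trait_id1 = 1, trait_id2 = 2:1)

# deduplicate once, then count keys per group and per group/trait
uni <- unique(dt, by = c("group_id1", "group_id2",
                         "trait_id1", "trait_id2", "key1", "key2"))
keys_per_group <- unique(uni, by = c("group_id1", "group_id2", "key1", "key2"))[
  , .N, by = .(group_id1, group_id2)]
keys_per_trait <- uni[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2)]
# keep group/trait combinations whose key count equals the group's key count
sel <- keys_per_trait[keys_per_group, on = .(group_id1, group_id2, N),
                      nomatch = 0L]
res <- dt[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
  order(id), -"N"]
res$id  # 1 2 3 5 6, matching the expected output above
```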

For the record, benchmarks on a 10 M row dataset:

> microbenchmark::microbenchmark(old_way = {res <- intersect_this_by(dt,
+                                                                    key_by = c("key1"),
+                                                                    split_by = c("group_id1", "group_id2"),
+                                                                    intersect_by = c("trait_id1", "trait_id2"))},
+                                new_way = {res <- intersect_this_by2(dt,
+                                                                     key_by = c("key1"),
+                                                                     split_by = c("group_id1", "group_id2"),
+                                                                     intersect_by = c("trait_id1", "trait_id2"))},
+                                new_way_uwe = {res <- intersect_this_by_uwe(dt,
+                                                                            key_by = c("key1"),
+                                                                            split_by = c("group_id1", "group_id2"),
+                                                                            intersect_by = c("trait_id1", "trait_id2"))},
+                                times = 10)
Unit: seconds
        expr       min        lq      mean    median        uq       max neval cld
     old_way  3.145468  3.530898  3.514020  3.544661  3.577814  3.623707    10  b 
     new_way 15.670487 15.792249 15.948385 15.988003 16.097436 16.206044    10   c
 new_way_uwe  1.982503  2.350001  2.320591  2.394206  2.412751  2.436381    10 a  

【Question comments】:

  • Does the order of the traits matter, i.e. would 3,2 and 2,3 both qualify? Or is it directional?
  • Yes, the order matters
  • Why is 4 missing from res? It has group = (2,1) and trait = (2,1), just like id=40. If it's a typo, then maybe dt[id %in% dt[, .SD[, if (.N > 1) id, by=.(trait_id1, trait_id2)], by=.(group_id1, group_id2)]$V1]
  • If order matters, I would paste the columns together and then group by that column within the grouping IDs, then filter for frequencies greater than 1.
  • Id 4 is also missing because it is trait 2-1 with key 2, but there is no key 1 with the same trait in the same group. Sorry, I think you missed the key idea, but I must admit my explanation was poor!

Tags: r merge data.table


【Solution 1】:

With the additional explanations by the OP, I believe to have gained a better understanding of the problem.

The OP wants to remove incomplete subsets from his dataset. Each group_id1, group_id2 group contains a set of unique key1 values. A complete subset contains at least one group_id1, group_id2, trait_id1, trait_id2, key1 record for every key1 value in the group_id1, group_id2 group.

When comparing the grouping on the group_id1, group_id2, trait_id1, trait_id2 level with the grouping on the group_id1, group_id2 level, it is not necessary to check which key1 values appear. It is sufficient to check that the numbers of distinct key1 values are equal.

Therefore, the solution below follows the general outline of the OP's own answer but uses two joins to achieve the result:

setkey(dt, group_id1, group_id2, trait_id1, trait_id2, key1)
uni <- unique(dt, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))
unique_keys_by_group <- 
  unique(uni, by = c("group_id1", "group_id2", "key1"))[, .N, by = .(group_id1, group_id2)]
unique_keys_by_group_and_trait <- 
  uni[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2)]
# 1st join to pick group/traits combinations with equal number of unique keys
selected_groups_and_traits <- 
  unique_keys_by_group_and_trait[unique_keys_by_group, 
                                 on = .(group_id1, group_id2, N), nomatch = 0L]
# 2nd join to pick records of valid subsets
res <- dt[selected_groups_and_traits, on = .(group_id1, group_id2, trait_id1, trait_id2)][
  order(id), -"N"]

It can be verified that this result is identical to the OP's result:

identical(
  intersect_this_by(dt,
                    key_by = c("key1"),
                    split_by = c("group_id1", "group_id2"),
                    intersect_by = c("trait_id1", "trait_id2")),
  res)
[1] TRUE

Note that the uniqueN() function is not used, due to performance issues, as shown in the benchmarks of my first (wrong) answer.

Benchmark comparison

The OP's benchmark data are used (10 M rows).

library(microbenchmark)
mb <- microbenchmark(
  old_way = {
    DT <- copy(dt)
    res <- intersect_this_by(DT,
                             key_by = c("key1"),
                             split_by = c("group_id1", "group_id2"),
                             intersect_by = c("trait_id1", "trait_id2"))
  },
  uwe = {
    DT <- copy(dt)
    setkey(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    uni <- 
      unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))
    unique_keys_by_group <- 
      unique(uni, by = c("group_id1", "group_id2", "key1"))[
        , .N, by = .(group_id1, group_id2)]
    unique_keys_by_group_and_trait <- 
      uni[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2)]
    selected_groups_and_traits <- 
      unique_keys_by_group_and_trait[unique_keys_by_group, 
                                     on = .(group_id1, group_id2, N), nomatch = 0L]
    res <- DT[selected_groups_and_traits, 
              on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id), -"N"]
  },
  times = 3L)
mb

The solution presented here is 40% faster:

Unit: seconds
    expr      min       lq     mean   median       uq      max neval cld
 old_way 7.251277 7.315796 7.350636 7.380316 7.400315 7.420315     3   b
     uwe 4.379781 4.461368 4.546267 4.542955 4.629510 4.716065     3  a

EDIT: Further performance improvements

The OP has asked for ideas to improve performance further.

I have tried different approaches, including a double nested grouping (using the slow uniqueN() just for simpler code display):

res <- DT[, {
  nuk_g = uniqueN(key1) 
  .SD[, if(nuk_g == uniqueN(key1)) .SD, by = .(trait_id1, trait_id2)]
}, by = .(group_id1, group_id2)][order(id)]

However, all of them turned out to be slower for the given benchmark data.

The performance of a particular method likely depends not only on the problem size, i.e., the number of rows, but also on the problem structure, e.g., the number of distinct groups, traits, and keys, as well as on the data types, etc.

So, without knowing the structure of your production data and the context of your computational flow, I do not think it is worthwhile to spend more time on benchmarking.

In any case, one suggestion: make sure setkey() is called only once, as it is rather costly (about 2 seconds) but speeds up all subsequent operations. (Verify with options(datatable.verbose = TRUE).)
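A small sketch of that suggestion (the toy columns g and x here are illustrative assumptions, not from the original post):

```r
library(data.table)

set.seed(123)
DT <- data.table(g = sample(1:100, 1e5, replace = TRUE), x = runif(1e5))

options(datatable.verbose = TRUE)  # data.table then reports key usage per call
setkey(DT, g)                 # pay the (costly) physical sort exactly once
head(DT[.(42L)])              # keyed subset: binary search, no vector scan
DT[, .(mx = max(x)), by = g]  # grouping detects the existing key, no re-sort
options(datatable.verbose = FALSE)
```

The verbose log for the subsequent calls shows that they reuse the sorted order instead of sorting again.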

【Comments】:

  • Thank you very much! It is simpler and faster than my first version. On my real dataset I get a 34% speed-up and the results match. Joining on N is genius! Would you have any ideas to make it even faster? :D
  • Thanks for the update. Do you know whether the order of the keys matters? (e.g., the column with the most groups first, or the fewest?)
  • It may well matter, but I have limited experience with setkey(), so I do not know.
【Solution 2】:

EDIT

While the answer below does reproduce the expected result for the small sample dataset, it fails to give the correct answer for the large 10 M row dataset provided by the OP.

However, I have decided to keep this wrong answer because of the benchmark results, which show the poor performance of the uniqueN() function. In addition, the answer contains benchmarks of much faster alternative solutions.



If I understand correctly, the OP only wants to keep those rows where the unique combination of group_id1, group_id2, trait_id1, trait_id2 appears in more than one distinct key1.

This can be achieved by counting the unique values of key1 within each group of group_id1, group_id2, trait_id1, trait_id2 and selecting only those combinations of group_id1, group_id2, trait_id1, trait_id2 where the count is greater than one. Finally, the matching rows are retrieved by a join:

library(data.table)
sel <- dt[, uniqueN(key1), by = .(group_id1, group_id2, trait_id1, trait_id2)][V1 > 1]
sel
   group_id1 group_id2 trait_id1 trait_id2 V1
1:         1         2         3         1  2
2:         2         2         2         1  2
3:         2         1         1         2  2
4:         1         1         1         1  2
5:         1         1         2         2  2
6:         2         2         2         2  2
7:         1         1         1         2  2
8:         1         1         3         2  2
res <- dt[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][order(id), -"V1"]
res
    id key1 group_id1 trait_id1 group_id2 trait_id2 extra
 1:  1    2         1         3         2         1     u
 2:  2    1         2         2         2         1     g
 3:  5    2         2         1         1         2     g
 4:  8    2         1         3         2         1     o
 5:  9    2         1         1         1         1     d
 6: 10    2         2         1         1         2     g
 7: 13    1         2         1         1         2     c
 8: 14    2         1         2         1         2     t
 9: 15    1         1         3         2         1     y
10: 16    2         1         3         2         1     v
11: 19    2         2         2         2         2     y
12: 22    2         2         2         2         1     g
13: 24    2         1         1         1         2     i
14: 25    1         1         3         1         2     n
15: 26    1         2         2         2         2     y
16: 27    1         1         1         1         1     n
17: 28    1         1         1         1         2     h
18: 29    1         2         2         2         2     b
19: 30    2         1         3         1         2     k
20: 31    1         2         2         2         2     w
21: 35    1         1         2         1         2     q
22: 37    2         2         1         1         2     r
23: 39    1         1         1         1         2     o
    id key1 group_id1 trait_id1 group_id2 trait_id2 extra

This reproduces the OP's expected result, but is it also the fastest way, as the OP requested?


Benchmarking, part 1

The OP's code is used to create the benchmark data (but with 1 M rows instead of 10 M rows):

set.seed(0)
n <- 1e6
p <- 1e5
m <- 5
dt <- data.table(id = 1:n,
                 key1 = sample(1:m, size = n, replace = TRUE),
                 group_id1 = sample(1:2, size = n, replace = TRUE),
                 trait_id1 = sample(1:p, size = n, replace = TRUE),
                 group_id2 = sample(1:2, size = n, replace = TRUE),
                 trait_id2 = sample(1:2, size = n, replace = TRUE),
                 extra = sample(letters, n, replace = TRUE))

I was surprised to find that the solution using uniqueN() is not the fastest one:

Unit: milliseconds
    expr       min        lq      mean    median        uq       max neval cld
 old_way  489.4606  496.3801  523.3361  503.2997  540.2739  577.2482     3 a  
 new_way 9356.4131 9444.5698 9567.4035 9532.7265 9672.8987 9813.0710     3   c
    uwe1 5946.4533 5996.7388 6016.8266 6047.0243 6052.0133 6057.0023     3  b

Benchmark code:

microbenchmark::microbenchmark(
  old_way = {
    DT <- copy(dt)
    res <- intersect_this_by(DT,
                             key_by = c("key1"),
                             split_by = c("group_id1", "group_id2"),
                             intersect_by = c("trait_id1", "trait_id2"))
  },
  new_way = {
    DT <- copy(dt)
    res <- intersect_this_by2(DT,
                              key_by = c("key1"),
                              split_by = c("group_id1", "group_id2"),
                              intersect_by = c("trait_id1", "trait_id2"))
  },
  uwe1 = {
    DT <- copy(dt)
    sel <- DT[, uniqueN(key1), by = .(group_id1, group_id2, trait_id1, trait_id2)][V1 > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  times = 3L)

Note that a fresh copy of the benchmark data is used for each run, in order to avoid any side effects from previous runs, e.g., indices set by data.table.

Turning on verbose mode with

options(datatable.verbose = TRUE)

shows that most of the time is spent computing uniqueN() for all the groups:

sel <- DT[, uniqueN(key1), by = .(group_id1, group_id2, trait_id1, trait_id2)][V1 > 1]

Detected that j uses these columns: key1 
Finding groups using forderv ... 0.060sec 
Finding group sizes from the positions (can be avoided to save RAM) ... 0.000sec 
Getting back original order ... 0.050sec 
lapply optimization is on, j unchanged as 'uniqueN(key1)'
GForce is on, left j unchanged
Old mean optimization is on, left j unchanged.
Making each group and running j (GForce FALSE) ... 
  collecting discontiguous groups took 0.084s for 570942 groups
  eval(j) took 5.505s for 570942 calls
5.940sec

This is a known issue. However, the alternative length(unique()) (of which uniqueN() is an abbreviation) brings only a moderate speed-up of about a factor of 2.

So I started to look for ways to avoid uniqueN() and length(unique()).
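For reference, a small self-contained check that the two idioms count the same thing (the simulated columns g and key1 are assumptions for illustration; the factor-2 claim above comes from the verbose timings and will vary with the data):

```r
library(data.table)

set.seed(42)
DT <- data.table(g    = sample(1:1e4, 1e5, replace = TRUE),
                 key1 = sample(1:5,   1e5, replace = TRUE))

# uniqueN(x) is a fast equivalent of length(unique(x)); both count distinct values
a <- DT[, uniqueN(key1), by = g]
b <- DT[, length(unique(key1)), by = g]
identical(a, b)
```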


Benchmarking, part 2

I found two alternatives which are sufficiently fast. In a first step, both create a data.table of the unique combinations of group_id1, group_id2, trait_id1, trait_id2, and key1; they then count the number of distinct key1 values per group of group_id1, group_id2, trait_id1, trait_id2 and filter for counts greater than one:

sel <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
  , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]

sel <- unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))[
  , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
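Both variants should select identical group/trait combinations; a quick self-contained check on simulated data shaped like the benchmark data (the row count and seed here are arbitrary assumptions):

```r
library(data.table)

set.seed(1)
n <- 1e5
DT <- data.table(group_id1 = sample(1:2,  n, replace = TRUE),
                 group_id2 = sample(1:2,  n, replace = TRUE),
                 trait_id1 = sample(1:50, n, replace = TRUE),
                 trait_id2 = sample(1:2,  n, replace = TRUE),
                 key1      = sample(1:5,  n, replace = TRUE))

# variant 3: two nested .N groupings; variant 4: unique() then one .N grouping
sel3 <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
  , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
sel4 <- unique(DT, by = c("group_id1", "group_id2",
                          "trait_id1", "trait_id2", "key1"))[
  , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
identical(sel3, sel4)
```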

The verbose output shows that the computing times of these variants are significantly better.

For benchmarking, only the fastest methods are used, but now with 10 M rows. In addition, each variant is tried with setkey() and setorder(), respectively, applied beforehand:

microbenchmark::microbenchmark(
  old_way = {
    DT <- copy(dt)
    res <- intersect_this_by(DT,
                             key_by = c("key1"),
                             split_by = c("group_id1", "group_id2"),
                             intersect_by = c("trait_id1", "trait_id2"))
  },
  uwe3 = {
    DT <- copy(dt)
    sel <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe3k = {
    DT <- copy(dt)
    setkey(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    sel <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe3o = {
    DT <- copy(dt)
    setorder(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    sel <- DT[, .N, by = .(group_id1, group_id2, trait_id1, trait_id2, key1)][
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe4 = {
    DT <- copy(dt)
    sel <- unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))[
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe4k = {
    DT <- copy(dt)
    setkey(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    sel <- unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))[
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  uwe4o = {
    DT <- copy(dt)
    setorder(DT, group_id1, group_id2, trait_id1, trait_id2, key1)
    sel <- unique(DT, by = c("group_id1", "group_id2", "trait_id1", "trait_id2", "key1"))[
      , .N, by = .(group_id1, group_id2, trait_id1, trait_id2)][N > 1]
    res <- DT[sel, on = .(group_id1, group_id2, trait_id1, trait_id2)][
      order(id)]
  },
  times = 3L)

The benchmark results for the 10 M case show that both variants are faster than the OP's intersect_this_by() function, and that keying and ordering drive the speed-up (with only a small advantage for ordering).

Unit: seconds
    expr      min       lq     mean   median       uq      max neval  cld
 old_way 7.173517 7.198064 7.256211 7.222612 7.297559 7.372506     3    d
    uwe3 6.820324 6.833151 6.878777 6.845978 6.908003 6.970029     3   c 
   uwe3k 5.349949 5.412018 5.436806 5.474086 5.480234 5.486381     3 a   
   uwe3o 5.423440 5.432562 5.467376 5.441683 5.489344 5.537006     3 a   
    uwe4 6.270724 6.276757 6.301774 6.282790 6.317299 6.351807     3  b  
   uwe4k 5.280763 5.295251 5.418803 5.309739 5.487823 5.665906     3 a   
   uwe4o 4.921627 5.095762 5.157010 5.269898 5.274702 5.279506     3 a

【Comments】:

  • Thanks for your answer, but this is still not the correct answer. The criterion is that all keys in the group should be represented by the trait combination in order to keep that trait/group combination. Please see my answer below, as I think it is clearer. We want to keep rows where uniqueN(keys) within the group-and-traits grouping is the same as uniqueN(keys) within the group_id grouping.
  • I added the expected number of rows for the 10 M case so you can check
【Solution 3】:

I will start with a tidyverse approach and then show the equivalent in data.table.

Please let me know if this result is not the intended one: it differs from your desired output, but it matches what you described in the text.

1. The tidy way

Simply create a single column from the traits, then group by the grouping columns and the new combined trait, and filter for group frequencies greater than 1.

library(dplyr)  # for %>%, mutate, group_by and filter

dt %>%
  mutate(comb = paste0(trait_id1, trait_id2)) %>%
  group_by(group_id1, group_id2, comb) %>%
  filter(n() > 1)

2. The data.table way

Roughly the same as the tidy way above, but written in data.table.

The fast paste method comes from the answers here.

dt[, comb := do.call(paste, c(.SD, sep = "")), .SDcols = c("trait_id1", "trait_id2")][, freq := .N, by = .(group_id1, group_id2, comb)][freq > 1]

Comparison

Comparing the two methods plus Chinsoon's comment, the speeds are:

microbenchmark::microbenchmark(zac_tidy = {
  dt %>%
    mutate(comb = paste0(trait_id1, trait_id2)) %>%
    group_by(group_id1, group_id2, comb) %>%
    filter(n() > 1)
},
zac_dt = {
  dt[, comb := do.call(paste, c(.SD, sep = "")), .SDcols = c("trait_id1", "trait_id2")][, freq := .N, by = .(group_id1, group_id2, comb)][freq > 1]
},
chin_dt = {
  dt[id %in% dt[, .SD[, if (.N > 1) id, by=.(trait_id1, trait_id2)], by=.(group_id1, group_id2)]$V1]
}, times = 100)

Unit: milliseconds
     expr      min       lq     mean   median       uq       max neval
 zac_tidy 4.151115 4.677328 6.150869 5.552710 7.765968  8.886388   100
   zac_dt 1.965013 2.201499 2.829999 2.640686 3.507516  3.831240   100
  chin_dt 4.567210 5.217439 6.972013 7.330628 8.233379 12.807005   100

> identical(zac_dt, chin_dt)
[1] TRUE

Comparison with 10 million rows

10 repetitions:

Unit: milliseconds
     expr       min        lq      mean    median       uq       max neval
 zac_tidy 12.492261 14.169898 15.658218 14.680287 16.31024 22.062874    10
   zac_dt 10.169312 10.967292 12.425121 11.402416 12.23311 21.036535    10
  chin_dt  6.381693  6.793939  8.449424  8.033886  9.78187 12.005604    10
 chin_dt2  5.536246  6.888020  7.914103  8.310142  8.74655  9.600121    10

Therefore I would recommend Chinsoon's method. Either one works, though.

【Comments】:

  • Perhaps time it with 10 million rows, since those are the OP's dimensions. chin_dt can be made faster with dt2[, .SD[, if (.N > 1) .SD, by=.(trait_id1, trait_id2)], by=.(group_id1, group_id2)]
  • You don't use a key?
  • Maybe we can forget the groups to start with and work within a single group defined by some values of the group_id columns?
  • @BenoitLondon Well, given that you didn't fully explain that, I'm not surprised. Let me take another look.
  • Thank you, much appreciated!
【Solution 4】:

The other answers do not solve the problem, but I found some approaches inspired by them. First compute the number of keys present in the group, and for each trait combination keep only those that are represented by the full number of keys:

intersect_this_by2 <- function(dt,
                               key_by = NULL,
                               split_by = NULL,
                               intersect_by = NULL){

    if (is.null(intersect_by) |
        is.null(key_by) |
        !is.data.frame(dt) |
        nrow(dt) == 0) {
        return(dt)
    }
    data_table_input <- is.data.table(dt)
    dtc <- as.data.table(dt)

    if (!is.null(split_by)) {
        # compute number of keys in the group
        dtc[, n_keys := uniqueN(.SD), by = split_by, .SDcols = key_by]
        # compute number of keys represented by each trait in each group 
        # and keep row only if they represent all keys from the group
        dtc[, keep := n_keys == uniqueN(.SD), by = c(intersect_by, split_by), .SDcols = key_by]
        dtc <- dtc[keep == TRUE][, c("n_keys", "keep") := NULL]
    } else {
        dtc[, n_keys := uniqueN(.SD), .SDcols = key_by]
        dtc[, keep := n_keys == uniqueN(.SD), by = c(intersect_by), .SDcols = key_by]
        dtc <- dtc[keep == TRUE][, c("n_keys", "keep") := NULL]
    }
    if (!data_table_input) {
        return(as.data.frame(dtc))
    } else {
        return(dtc)
    }
}

The problem is that it is much slower on my real dataset (5 to 6 times slower), but I think this function helps to understand the problem better. A dataset closer to my real data is also defined below:

pacman::p_load(data.table, microbenchmark, testthat)

set.seed(0)
n <- 1e7
p <- 1e5
m <- 5
dt <- data.table(id = 1:n,
                 key1 = sample(1:m, size = n, replace = TRUE),
                 group_id1 = sample(1:2, size = n, replace = TRUE),
                 trait_id1 = sample(1:p, size = n, replace = TRUE),
                 group_id2 = sample(1:2, size = n, replace = TRUE),
                 trait_id2 = sample(1:2, size = n, replace = TRUE),
                 extra = sample(letters, n, replace = TRUE))
microbenchmark::microbenchmark(old_way = {res <- intersect_this_by(dt,
                                                                    key_by = c("key1"),
                                                                    split_by = c("group_id1", "group_id2"),
                                                                    intersect_by = c("trait_id1", "trait_id2"))},
                               new_way = {res <- intersect_this_by2(dt,
                                                                   key_by = c("key1"),
                                                                   split_by = c("group_id1", "group_id2"),
                                                                   intersect_by = c("trait_id1", "trait_id2"))},
                               times = 1)


Unit: seconds
    expr       min        lq      mean    median        uq       max neval
 old_way  5.891489  5.891489  5.891489  5.891489  5.891489  5.891489     1
 new_way 18.455860 18.455860 18.455860 18.455860 18.455860 18.455860     1

For information, the number of rows of res in this example is:

> set.seed(0)
> n <- 1e7
> p <- 1e5
> m <- 5
> dt <- data.table(id = 1:n,
                   key1 = sample(1:m, size = n, replace = TRUE),
                   group_id1 = sample(1:2, size = n, replace = TRUE),
                   trait_id1 = sample(1:p, size = n, replace = TRUE),
                   group_id2 = sample(1:2, size = n, replace = TRUE),
                   trait_id2 = sample(1:2, size = n, replace = TRUE),
                   extra = sample(letters, n, replace = TRUE))
> res <- intersect_this_by(dt,
                            key_by = c("key1"),
                            split_by = c("group_id1", "group_id2"),
                            intersect_by = c("trait_id1", "trait_id2"))
> nrow(res)
[1] 7099860
> res <- intersect_this_by2(dt,
                            key_by = c("key1"),
                            split_by = c("group_id1", "group_id2"),
                            intersect_by = c("trait_id1", "trait_id2"))
> nrow(res)
[1] 7099860

【Comments】:

  • The slow part seems to be the second uniqueN; in fact we only need to check against n_keys