【问题标题】:Faster weighted sampling without replacement无需更换即可实现更快的加权采样
【发布时间】:2013-02-13 08:27:03
【问题描述】:

这个问题引出了一个新的 R 包: wrswoR

R 的默认采样不使用 sample.int 进行替换似乎需要二次运行时间,例如当使用从均匀分布中提取的权重时。对于大样本量,这很慢。有谁知道一个更快的实现可以在 R 中使用?两个选项是“带替换的拒绝抽样”(参见 stats.sx 上的 this question)和 Wong and Easton (1980) 的算法(在 StackOverflow answer 中使用 Python 实现)。

感谢 Ben Bolker 提示当使用 replace=F 和非均匀权重调用 sample.int 时内部调用的 C 函数:ProbSampleNoReplace。实际上,代码显示了两个嵌套的 for 循环(random.c 的第 420 行 ff)。

这是根据经验分析运行时间的代码:

library(plyr)

sample.int.test <- function(n, p) {
    sample.int(2 * n, n, replace=F, prob=p); NULL }

times <- ldply(
  1:7,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    data.frame(
      n=n,
      user=system.time(sample.int.test(n, p), gcFirst=T)['user.self'])
  },
  .progress='text'
)

times

library(ggplot2)
ggplot(times, aes(x=n, y=user/n)) + geom_point() + scale_x_log10() +
  ylab('Time per unit (s)')

# Output:
       n   user
1   2048  0.008
2   4096  0.028
3   8192  0.100
4  16384  0.408
5  32768  1.645
6  65536  6.604
7 131072 26.558

编辑:感谢 Arun 指出未加权采样似乎没有这种性能损失。

【问题讨论】:

  • 我不明白为什么 runif() 会使运行时 二次 ...
  • 没有runif(2*n) 的相同语句运行时间为 0.001 秒。
  • 看来sampleprob 比较耗时。
  • 所使用的算法可在 github.com/wch/r-source/blob/trunk/src/main/random.c 查看:搜索 ProbSampleReplace。我不知道这是否有用,但它应该让您大致了解所使用的算法以及它是否可以轻松改进。我注意到它正在对整个向量进行排序...
  • 我唯一可以建议的另一件事是您尝试library("sos"): findFn("{sampling without replacement}") 并整理结果以查看是否有任何有用的信息

标签: r performance algorithm


【解决方案1】:

更新:

Efraimidis & Spirakis 算法的Rcpp 实现(感谢@Hemmo、@Dinrem、@krlmlr 和@rtlgrmpf):

library(inline)
library(Rcpp)
src <- 
'
int num = as<int>(size), x = as<int>(n);
Rcpp::NumericVector vx = Rcpp::clone<Rcpp::NumericVector>(x);
Rcpp::NumericVector pr = Rcpp::clone<Rcpp::NumericVector>(prob);
Rcpp::NumericVector rnd = rexp(x) / pr;
for(int i= 0; i<vx.size(); ++i) vx[i] = i;
std::partial_sort(vx.begin(), vx.begin() + num, vx.end(), Comp(rnd));
vx = vx[seq(0, num - 1)] + 1;
return vx;
'
incl <- 
'
struct Comp{
  Comp(const Rcpp::NumericVector& v ) : _v(v) {}
  bool operator ()(int a, int b) { return _v[a] < _v[b]; }
  const Rcpp::NumericVector& _v;
};
'
funFast <- cxxfunction(signature(n = "Numeric", size = "integer", prob = "numeric"),
                       src, plugin = "Rcpp", include = incl)

# See the bottom of the answer for comparison
p <- c(995/1000, rep(1/1000, 5))
n <- 100000
system.time(print(table(replicate(funFast(6, 3, p), n = n)) / n))

      1       2       3       4       5       6 
1.00000 0.39996 0.39969 0.39973 0.40180 0.39882 
   user  system elapsed 
   3.93    0.00    3.96 
# In case of:
# Rcpp::IntegerVector vx = Rcpp::clone<Rcpp::IntegerVector>(x);
# i.e. instead of NumericVector
      1       2       3       4       5       6 
1.00000 0.40150 0.39888 0.39925 0.40057 0.39980 
   user  system elapsed 
   1.93    0.00    2.03 

旧版本:

让我们尝试几种可能的方法:

带替换的简单拒绝抽样。这是一个比@krlmlr 提供的sample.int.rej 简单得多的功能,即样本大小始终等于n。正如我们将看到的,假设权重均匀分布,它仍然非常快,但在另一种情况下非常慢。

fastSampleReject <- function(all, n, w){
  out <- numeric(0)
  while(length(out) < n)
    out <- unique(c(out, sample(all, n, replace = TRUE, prob = w)))
  out[1:n]
}

Wong 和 Easton (1980) 的算法。这是this Python 版本的实现。它很稳定,我可能会遗漏一些东西,但与其他功能相比它要慢得多。

fastSample1980 <- function(all, n, w){
  tws <- w
  for(i in (length(tws) - 1):0)
    tws[1 + i] <- sum(tws[1 + i], tws[1 + 2 * i + 1], 
                      tws[1 + 2 * i + 2], na.rm = TRUE)      
  out <- numeric(n)
  for(i in 1:n){
    gas <- tws[1] * runif(1)
    k <- 0        
    while(gas > w[1 + k]){
      gas <- gas - w[1 + k]
      k <- 2 * k + 1
      if(gas > tws[1 + k]){
        gas <- gas - tws[1 + k]
        k <- k + 1
      }
    }
    wgh <- w[1 + k]
    out[i] <- all[1 + k]        
    w[1 + k] <- 0
    while(1 + k >= 1){
      tws[1 + k] <- tws[1 + k] - wgh
      k <- floor((k - 1) / 2)
    }
  }
  out
}

Wong 和 Easton 对算法的 Rcpp 实现。可能它可以进一步优化,因为这是我第一个可用的 Rcpp 函数,但无论如何它运行良好。

library(inline)
library(Rcpp)

src <-
'
Rcpp::NumericVector weights = Rcpp::clone<Rcpp::NumericVector>(w);
Rcpp::NumericVector tws = Rcpp::clone<Rcpp::NumericVector>(w);
Rcpp::NumericVector x = Rcpp::NumericVector(all);
int k, num = as<int>(n);
Rcpp::NumericVector out(num);
double gas, wgh;

if((weights.size() - 1) % 2 == 0){
  tws[((weights.size()-1)/2)] += tws[weights.size()-1] + tws[weights.size()-2];
}
else
{
  tws[floor((weights.size() - 1)/2)] += tws[weights.size() - 1];
}

for (int i = (floor((weights.size() - 1)/2) - 1); i >= 0; i--){
  tws[i] += (tws[2 * i + 1]) + (tws[2 * i + 2]);
}
for(int i = 0; i < num; i++){
  gas = as<double>(runif(1)) * tws[0];
  k = 0;
  while(gas > weights[k]){
    gas -= weights[k];
    k = 2 * k + 1;
    if(gas > tws[k]){
      gas -= tws[k];
      k += 1;
    }
  }
  wgh = weights[k];
  out[i] = x[k];
  weights[k] = 0;
  while(k > 0){
    tws[k] -= wgh;
    k = floor((k - 1) / 2);
  }
  tws[0] -= wgh;
}
return out;
'

fun <- cxxfunction(signature(all = "numeric", n = "integer", w = "numeric"),
                   src, plugin = "Rcpp")

现在有一些结果:

times1 <- ldply(
  1:6,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n) # Uniform distribution
    p <- p/sum(p)
    data.frame(
      n=n,
      user=c(system.time(sample.int.test(n, p), gcFirst=T)['user.self'],
             system.time(weighted_Random_Sample(1:(2*n), p, n), gcFirst=T)['user.self'],
             system.time(fun(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(sample.int.rej(2*n, n, p), gcFirst=T)['user.self'],
             system.time(fastSampleReject(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(fastSample1980(1:(2*n), n, p), gcFirst=T)['user.self']),
      id=c("Base", "Reservoir", "Rcpp", "Rejection", "Rejection simple", "1980"))
  },
  .progress='text'
)


times2 <- ldply(
  1:6,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n - 1)
    p <- p/sum(p) 
    p <- c(0.999, 0.001 * p) # Special case
    data.frame(
      n=n,
      user=c(system.time(sample.int.test(n, p), gcFirst=T)['user.self'],
             system.time(weighted_Random_Sample(1:(2*n), p, n), gcFirst=T)['user.self'],
             system.time(fun(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(sample.int.rej(2*n, n, p), gcFirst=T)['user.self'],
             system.time(fastSampleReject(1:(2*n), n, p), gcFirst=T)['user.self'],
             system.time(fastSample1980(1:(2*n), n, p), gcFirst=T)['user.self']),
      id=c("Base", "Reservoir", "Rcpp", "Rejection", "Rejection simple", "1980"))
  },
  .progress='text'
)

arrange(times1, id)
       n  user               id
1   2048  0.53             1980
2   4096  0.94             1980
3   8192  2.00             1980
4  16384  4.32             1980
5  32768  9.10             1980
6  65536 21.32             1980
7   2048  0.02             Base
8   4096  0.05             Base
9   8192  0.18             Base
10 16384  0.75             Base
11 32768  2.99             Base
12 65536 12.23             Base
13  2048  0.00             Rcpp
14  4096  0.01             Rcpp
15  8192  0.03             Rcpp
16 16384  0.07             Rcpp
17 32768  0.14             Rcpp
18 65536  0.31             Rcpp
19  2048  0.00        Rejection
20  4096  0.00        Rejection
21  8192  0.00        Rejection
22 16384  0.02        Rejection
23 32768  0.02        Rejection
24 65536  0.03        Rejection
25  2048  0.00 Rejection simple
26  4096  0.01 Rejection simple
27  8192  0.00 Rejection simple
28 16384  0.01 Rejection simple
29 32768  0.00 Rejection simple
30 65536  0.05 Rejection simple
31  2048  0.00        Reservoir
32  4096  0.00        Reservoir
33  8192  0.00        Reservoir
34 16384  0.02        Reservoir
35 32768  0.03        Reservoir
36 65536  0.05        Reservoir

arrange(times2, id)
       n  user               id
1   2048  0.43             1980
2   4096  0.93             1980
3   8192  2.00             1980
4  16384  4.36             1980
5  32768  9.08             1980
6  65536 19.34             1980
7   2048  0.01             Base
8   4096  0.04             Base
9   8192  0.18             Base
10 16384  0.75             Base
11 32768  3.11             Base
12 65536 12.04             Base
13  2048  0.01             Rcpp
14  4096  0.02             Rcpp
15  8192  0.03             Rcpp
16 16384  0.08             Rcpp
17 32768  0.15             Rcpp
18 65536  0.33             Rcpp
19  2048  0.00        Rejection
20  4096  0.00        Rejection
21  8192  0.02        Rejection
22 16384  0.02        Rejection
23 32768  0.05        Rejection
24 65536  0.08        Rejection
25  2048  1.43 Rejection simple
26  4096  2.87 Rejection simple
27  8192  6.17 Rejection simple
28 16384 13.68 Rejection simple
29 32768 29.74 Rejection simple
30 65536 73.32 Rejection simple
31  2048  0.00        Reservoir
32  4096  0.00        Reservoir
33  8192  0.02        Reservoir
34 16384  0.02        Reservoir
35 32768  0.02        Reservoir
36 65536  0.04        Reservoir

显然我们可以拒绝函数1980,因为在这两种情况下它都比Base 慢。 Rejection simple 在第二种情况下只有一个概率为 0.999 时也会遇到麻烦。

所以还有RejectionRcppReservoir。最后一步是检查值本身是否正确。为确定它们,我们将使用sample 作为基准(同时消除由于无替换抽样而不必与p 重合的概率的混淆)。

p <- c(995/1000, rep(1/1000, 5))
n <- 100000

system.time(print(table(replicate(sample(1:6, 3, repl = FALSE, prob = p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.39992 0.39886 0.40088 0.39711 0.40323  # Benchmark
   user  system elapsed 
   1.90    0.00    2.03 

system.time(print(table(replicate(sample.int.rej(2*3, 3, p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.40007 0.40099 0.39962 0.40153 0.39779 
   user  system elapsed 
  76.02    0.03   77.49 # Slow

system.time(print(table(replicate(weighted_Random_Sample(1:6, p, 3), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.49535 0.41484 0.36432 0.36338 0.36211  # Incorrect
   user  system elapsed 
   3.64    0.01    3.67 

system.time(print(table(replicate(fun(1:6, 3, p), n = n))/n))
      1       2       3       4       5       6 
1.00000 0.39876 0.40031 0.40219 0.40039 0.39835 
   user  system elapsed 
   4.41    0.02    4.47 

请注意这里的一些事情。出于某种原因,weighted_Random_Sample 返回了不正确的值(我根本没有研究过它,但假设分布均匀,它可以正常工作)。 sample.int.rej 重复采样很慢。

总之,Rcpp 似乎是重复采样情况下的最佳选择,而sample.int.rej 在其他情况下速度更快,也更易于使用。

【讨论】:

  • 非常好,尤其是测试采样器的代码! weighted_Random_Sample 可能受到 IEEE 浮点值精度有限的影响。不幸的是,Dinre 尚未回复我关于该主题的请求,但请参阅 math.sx 上的 this related question :-)
  • 请检查我的包裹中的sample.int.rank,文件sample_int_rank.R。此函数是通过您的分布测试的水库采样实现。
  • @krlmlr,我明白了,现在它是算法优雅竞赛的有效获胜者 :) 但是,重复采样很慢。今天晚些时候我会更新一点我的答案。
  • 只是为了在这个过程中多加一把扳手,在我的机器上,我在基准 1.0000 0.4017 0.4002 0.4091 0.3917 0.3973 和 weighted_Random_Sample 1.0000 0.3991 0.4056 0.4012 0.4014 0.3927 之间得到了几乎相同的结果。不过,我使用的是p &lt;- c(955, rep(1, 5)),因为这实际上是一种更可靠的采样方法。所有主要的采样函数都接受权重并且不假设sum(prob) == 1,所以我更喜欢在我的所有工作中使用权重。不过,这可能会限制 WRS 算法的使用方式。
  • @Dinre:这表明您的实现取决于权重的大小——使用-rexp(n) / prob = log(runif(n)) / prob 在数值上比runif(n) ^ (1 / prob) 更健壮,并且等效于w.r.t。排序顺序。
【解决方案2】:

让我介绍一下我自己的基于rejection sampling with replacement 的更快方法的实现。思路是这样的:

  • 生成一个样本替换“有点”大于请求的大小

  • 丢弃重复值

  • 如果没有绘制足够的值,则使用调整后的nsizeprob 参数递归调用相同的过程

  • 将返回的索引重新映射到原始索引

我们需要抽取多大的样本?假设均匀分布,结果是expected number of trials to see x unique values out of N total values。这是两个harmonic numbers(H_n 和 H_{n - size})的差异。前几个谐波数列在表格中,否则使用使用自然对数的近似值。 (这只是一个大概的数字,这里不需要太精确。)现在,对于非均匀分布,预期要抽取的项目数只能更大,所以我们不会抽取太多样本。此外,抽取的样本数量受到人口规模两倍的限制——我假设进行一些递归调用比抽样多达 O(n ln n) 个项目要快。

代码可在sample_int_rej.Rsample.int.rej 例程中的R 包wrswoR 中获得。安装:

library(devtools)
install_github('wrswoR', 'muelleki')

它的工作似乎“足够快”,但尚未进行正式的运行时测试。此外,该软件包仅在 Ubuntu 中进行了测试。感谢您的反馈。

【讨论】:

  • 是的,假设均匀分布它很快,但假设不是那么方便,它变得非常糟糕。今天我将发布一个关于它的答案,this 的 R 实现,我将尝试完成它的Rcpp 版本。
  • @Julius:期待基准测试 :-) 我已经用权重的指数分布测试了我的代码(这是使用 IEEE 浮点数可以得到的最糟糕的结果),期待真正可怕的行为,但对我来说惊讶它并没有那么糟糕......
  • @krlmlr,按照您描述算法的方式,包含概率将与使用的权重成正比。
  • @Ferdinand.kraft:奥利?需要详细说明吗?
  • @krlmlr 在拒绝抽样中,如果在替换步骤中有重复,则必须丢弃整个样本。即使这样,包含概率也以复杂的非线性方式与权重相关。在基于设计的推理中,我们需要最终的包含概率,而不是算法内部使用的权重。所以我的观点是:如何根据您提出的方法抽取样本的重量和大小来计算它们?
【解决方案3】:

我决定深入研究一些 cmets,发现 Efraimidis & Spirakis 论文很吸引人(感谢 @Hemmo 找到参考)。论文中的总体思路是这样的:通过生成一个随机统一数并将其提高到每个项目的权重的 1 次方来创建一个密钥。然后,您只需将最高的键值作为样本。效果非常好!

weighted_Random_Sample <- function(
    .data,
    .weights,
    .n
    ){

    key <- runif(length(.data)) ^ (1 / .weights)
    return(.data[order(key, decreasing=TRUE)][1:.n])
}

如果您将 '.n' 设置为 '.data' 的长度(应该始终是 '.weights' 的长度),这实际上是一个加权的储层置换,但该方法适用于采样和排列。

更新:我应该提一下,上面的函数期望权重大于零。否则key &lt;- runif(length(.data)) ^ (1 / .weights) 将无法正确订购。


只是为了好玩,我还使用了 OP 中的测试场景来比较这两个功能。

set.seed(1)

times_WRS <- ldply(
1:7,
function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    n_Set <- 1:(2 * n)
    data.frame(
      n=n,
      user=system.time(weighted_Random_Sample(n_Set, p, n), gcFirst=T)['user.self'])
  },
  .progress='text'
)

sample.int.test <- function(n, p) {
sample.int(2 * n, n, replace=F, prob=p); NULL }

times_sample.int <- ldply(
  1:7,
  function(i) {
    n <- 1024 * (2 ** i)
    p <- runif(2 * n)
    data.frame(
      n=n,
      user=system.time(sample.int.test(n, p), gcFirst=T)['user.self'])
  },
  .progress='text'
)

times_WRS$group <- "WRS"
times_sample.int$group <- "sample.int"
library(ggplot2)

ggplot(rbind(times_WRS, times_sample.int) , aes(x=n, y=user/n, col=group)) + geom_point() + scale_x_log10() +  ylab('Time per unit (s)')

以下是时间:

times_WRS
#        n user
# 1   2048 0.00
# 2   4096 0.01
# 3   8192 0.00
# 4  16384 0.01
# 5  32768 0.03
# 6  65536 0.06
# 7 131072 0.16

times_sample.int
#        n  user
# 1   2048  0.02
# 2   4096  0.05
# 3   8192  0.14
# 4  16384  0.58
# 5  32768  2.33
# 6  65536  9.23
# 7 131072 37.79

【讨论】:

  • 这真的和没有替换的加权抽样一样,还是只是一个近似值?我在math.sx 上问过这个问题,但没有得到答案...
  • 嗯,所有形式的随机抽样都是近似值,所以我想这两个问题的答案都是“是”。更具体地说,这种方法确实满足了我熟悉的所有抽样标准,因为在这种情况下,2 对 1 的权重被选择为序列中的第一个元素的可能性是其两倍。需要注意的是,在不可替换的情况下,唯一的问题是在哪里将放置一个项目,而不是如果
  • ...除非您只想对部分人口进行抽样...?
  • 没有。这是同一件事,因为您没有使用替换。您仍然订购整套,但您从订购的集合中选择 1:n 数量的项目。
  • Pavlos S. Efraimidis 和 Paul G. Spirakis 的算法是迄今为止我见过的最漂亮的东西,只是因为它很简单。它就像通过 FFT 实现卷积一样甜蜜,但不确定哪个会胜出…… 注意:作者证明他们的算法等效于加权随机采样而无需替换。
猜你喜欢
  • 2012-01-02
  • 1970-01-01
  • 1970-01-01
  • 2013-10-05
  • 1970-01-01
  • 1970-01-01
  • 2019-05-04
  • 1970-01-01
  • 2010-10-15
相关资源
最近更新 更多