【发布时间】:2026-01-24 20:10:02
【问题描述】:
我有一个长度为n 的布尔值numpy 数组mask。我还有一个numpy 数组a,长度n,包含从0(含)到n-1(含)的数字,并且它不包含重复项。我要计算的查询是np.array([x for x in a if mask[x]]),但我认为这不是最快的方法。
在numpy 中是否有比我刚才写的更快的方法?
【问题讨论】:
标签: numpy optimization
我有一个长度为n 的布尔值numpy 数组mask。我还有一个numpy 数组a,长度n,包含从0(含)到n-1(含)的数字,并且它不包含重复项。我要计算的查询是np.array([x for x in a if mask[x]]),但我认为这不是最快的方法。
在numpy 中是否有比我刚才写的更快的方法?
【问题讨论】:
标签: numpy optimization
看起来最快的方法就是a[mask[a]]。我写了一个快速测试,显示了两种方法的速度差异,具体取决于掩码的覆盖率 p(真实项目数 / n)。
import timeit
import matplotlib.pyplot as plt
import numpy as np
n = 10000
p = 0.25
slow_times = []
fast_times = []
p_space = np.linspace(0, 1, 100)
for p in p_space:
mask = np.random.choice([True, False], n, p=[p, 1 - p])
a = np.arange(n)
np.random.shuffle(a)
y = np.array([x for x in a if mask[x]])
z = a[mask[a]]
n_test = 100
t1 = timeit.timeit(lambda: np.array([x for x in a if mask[x]]), number=n_test)
t2 = timeit.timeit(lambda: a[mask[a]], number=n_test)
slow_times.append(t1)
fast_times.append(t2)
plt.plot(p_space, slow_times, label='slow')
plt.plot(p_space, fast_times, label='fast')
plt.xlabel('p (# true items in mask)')
plt.ylabel('time (ms)')
plt.legend()
plt.title('Speed of method vs. coverage of mask')
plt.show()
这给了我这个情节
所以无论遮罩的覆盖范围如何,这种方法都快得多。
【讨论】: