【问题标题】:Fast way to compute this numpy query计算这个 numpy 查询的快速方法
【发布时间】:2026-01-24 20:10:02
【问题描述】:

我有一个长度为n 的布尔值numpy 数组mask。我还有一个numpy 数组a,长度n,包含从0(含)到n-1(含)的数字,并且它不包含重复项。我要计算的查询是np.array([x for x in a if mask[x]]),但我认为这不是最快的方法。

numpy 中是否有比我刚才写的更快的方法?

【问题讨论】:

    标签: numpy optimization


    【解决方案1】:

    看起来最快的方法就是a[mask[a]]。我写了一个快速测试,显示了两种方法的速度差异,具体取决于掩码的覆盖率 p(真实项目数 / n)。

    import timeit
    import matplotlib.pyplot as plt
    import numpy as np
    n = 10000
    p = 0.25
    slow_times = []
    fast_times = []
    p_space = np.linspace(0, 1, 100)
    for p in p_space:
        mask = np.random.choice([True, False], n, p=[p, 1 - p])
        a = np.arange(n)
        np.random.shuffle(a)
        y = np.array([x for x in a if mask[x]])
        z = a[mask[a]]
        n_test = 100
        t1 = timeit.timeit(lambda: np.array([x for x in a if mask[x]]), number=n_test)
        t2 = timeit.timeit(lambda: a[mask[a]], number=n_test)
        slow_times.append(t1)
        fast_times.append(t2)
    plt.plot(p_space, slow_times, label='slow')
    plt.plot(p_space, fast_times, label='fast')
    plt.xlabel('p (# true items in mask)')
    plt.ylabel('time (ms)')
    plt.legend()
    plt.title('Speed of method vs. coverage of mask')
    plt.show()
    

    这给了我这个情节

    所以无论遮罩的覆盖范围如何,这种方法都快得多。

    【讨论】: