【问题标题】:Fast combinations without replacement for arrays - NumPy / Python无需替换数组的快速组合 - NumPy / Python
【发布时间】:2019-12-25 03:41:05
【问题描述】:

我正在从一维数组高效地生成成对组合。如果 n > 1000

,Itertools 的效率太低了
E.g. [1, 2, 3, 4]

magic code...

Out[2]:
array([[1, 2],
       [1, 3],
       [1, 4],
       [2, 3],
       [2, 4],
       [3, 4]])

最接近它的是here

【问题讨论】:

  • divakar 的回答指出了这个问题的解决方案(stackoverflow.com/questions/11144513/…),我想补充一点,这个问题和其他答案也很有用。
  • @ItamarMushkin 事实证明,我们可以做得更好 :) 删除了该解决方案。
  • 哦。认为这个问题仍然值得一游?
  • @ItamarMushkin 还是值得的。问答中有很多有用的想法。

标签: python performance numpy combinations


【解决方案1】:

我。成对组合

一种方法是使用 numba 来获得内存并因此提高性能 -

from numba import njit

@njit
def pairwise_combs_numba(a):
    n = len(a)
    L = n*(n-1)//2
    out = np.empty((L,2),dtype=a.dtype)
    iterID = 0
    for i in range(n):
        for j in range(i+1,n):
            out[iterID,0] = a[i]
            out[iterID,1] = a[j]
            iterID += 1
    return out

另一个基于 NumPy 的方法是使用 np.broadcast_to 获取网格视图,然后进行屏蔽 -

def pairwise_combs_mask(a):
    n = len(a)
    L = n*(n-1)//2
    out = np.empty((L,2),dtype=a.dtype)
    m = ~np.tri(len(a),dtype=bool)
    out[:,0] = np.broadcast_to(a[:,None],(n,n))[m]
    out[:,1] = np.broadcast_to(a,(n,n))[m]
    return out

二。三重组合

我们将扩展相同的方法来获得三重组合 -

@njit
def triplet_combs_numba(a):
    n = len(a)
    L = n*(n-1)*(n-2)//6
    out = np.empty((L,3),dtype=a.dtype)
    iterID = 0
    for i in range(n):
        for j in range(i+1,n):
            for k in range(j+1,n):
                out[iterID,0] = a[i]
                out[iterID,1] = a[j]
                out[iterID,2] = a[k]
                iterID += 1
    return out

def triplet_combs_mask(a):
    n = len(a)
    L = n*(n-1)*(n-2)//6
    out = np.empty((L,3),dtype=a.dtype)
    r = np.arange(n)
    m = (r[:,None,None]<r[:,None]) & (r[:,None]<r)
    out[:,0] = np.broadcast_to(a[:,None,None],(n,n,n))[m]
    out[:,1] = np.broadcast_to(a[None,:,None],(n,n,n))[m]
    out[:,2] = np.broadcast_to(a[None,None,:],(n,n,n))[m]
    return out

更高阶的组合也会同样扩展。

示例运行 -

In [54]: a = np.array([3,9,4,1,7])

In [55]: pairwise_combs_numba(a)
Out[55]: 
array([[3, 9],
       [3, 4],
       [3, 1],
       [3, 7],
       [9, 4],
       [9, 1],
       [9, 7],
       [4, 1],
       [4, 7],
       [1, 7]])

In [56]: triplet_combs_numba(a)
Out[56]: 
array([[3, 9, 4],
       [3, 9, 1],
       [3, 9, 7],
       [3, 4, 1],
       [3, 4, 7],
       [3, 1, 7],
       [9, 4, 1],
       [9, 4, 7],
       [9, 1, 7],
       [4, 1, 7]])

计时(包括 Python 的内置函数 - itertools.combinations) -

In [68]: a = np.random.rand(4000)

In [69]: %timeit pairwise_combs_numba(a)
    ...: %timeit pairwise_combs_mask(a)
    ...: %timeit list(itertools.combinations(a, 2))
10 loops, best of 3: 52.2 ms per loop
10 loops, best of 3: 146 ms per loop
1 loop, best of 3: 597 ms per loop

In [70]: a = np.random.rand(400)

In [71]: %timeit triplet_combs_numba(a)
    ...: %timeit triplet_combs_mask(a)
    ...: %timeit list(itertools.combinations(a, 3))
10 loops, best of 3: 98.5 ms per loop
1 loop, best of 3: 352 ms per loop
1 loop, best of 3: 795 ms per loop

【讨论】:

    猜你喜欢
    • 2012-11-15
    • 2011-03-25
    • 2011-12-21
    • 2016-12-29
    • 1970-01-01
    • 1970-01-01
    • 2015-12-12
    • 2010-12-25
    相关资源
    最近更新 更多