【问题标题】:Numpy: efficient way to generate combinations from given rangesNumpy:从给定范围生成组合的有效方法
【发布时间】:2014-12-04 04:53:51
【问题描述】:

我有一个如下所示的n维数组:

np.array([[0,3],[0,3],[0,10]])

在此数组中,元素表示低值和高值。例如:[0,3] 指的是[0,1,2,3]

我需要使用上面给出的范围生成所有值的组合。 比如我要[0,0,0], [0,0,1] ... [0,1,0] ... [3,3,10]

我已经尝试了以下方法来得到我想要的:

ds = np.array([[0,3],[0,3],[0,10]])
nItems = int(reduce(lambda a,b: a * (b[1] - b[0] + 1), ds, 1))
myCombinations = np.zeros((nItems,))
nArrays = []
for x in range(ds.shape[0]):
    low = ds[x][0]
    high= ds[x][1]
    nitm = high - low + 1
    ar = [x+low for x in range(nitm) ]
    nArrays.append(ar)

myCombinations = cartesian(nArrays)

笛卡尔函数取自Using numpy to build an array of all combinations of two arrays

我需要这样做几百万次

我的问题:有没有更好/更有效的方法来做到这一点?

【问题讨论】:

    标签: python arrays numpy combinations


    【解决方案1】:

    我认为您正在寻找的是np.mgrid。不幸的是,这会以与您需要的格式不同的格式返回数组,因此您需要进行一些后期处理:

    a = np.mgrid[0:4, 0:4, 0:11]     # All points in a 3D grid within the given ranges
    a = np.rollaxis(a, 0, 4)         # Make the 0th axis into the last axis
    a = a.reshape((4 * 4 * 11, 3))   # Now you can safely reshape while preserving order
    

    说明

    np.mgrid 为您提供 N 维空间中的一组网格点。让我尝试用一​​个更小的例子来说明这一点,以使事情更清楚:

    >>> a = np.mgrid[0:2, 0:2]
    >>> a
    array([[[0, 0],
            [1, 1]],
    
           [[0, 1],
            [0, 1]]])
    

    由于我给出了两组范围0:2, 0:2,因此我得到了一个二维网格。 mgrid 返回的是二维空间中网格点 (0, 0)、(0, 1)、(1, 0) 和 (1, 1) 对应的 x 值和 y 值。 a[0] 告诉你四个点的 x 值是多少,a[1] 告诉你 y 值是多少。

    但您真正想要的是我写出的实际网格点列表,而不是这些点的 x 值和 y 值。第一直觉是根据需要重塑数组:

    >>> a.reshape((4, 2))
    array([[0, 0],
           [1, 1],
           [0, 1],
           [0, 1]])
    

    但显然这不起作用,因为它有效地重塑了扁平数组(通过按顺序读取所有元素获得的数组),而这不是你想要的。

    你要做的是向下看a第三个​​维度,并创建一个数组:

    [ [a[0][0, 0], a[1][0, 0]],
      [a[0][0, 1], a[1][0, 1]],
      [a[0][1, 0], a[1][1, 0]],
      [a[0][1, 1], a[1][1, 1]] ]
    

    上面写着“首先告诉我第一个点 (x1, y1),然后是第二个点 (x2, y2), ...”等等。也许用某种数字来更好地解释这一点。这就是a 的样子:

                    you want to read
                    in this direction
                     (0, 0)   (0, 1)
                       |        |
                       |        |
                       v        v
    
              /        0--------0            +----> axis0
     x-values |       /|       /|           /|
              |      / |      / |    axis1 / |
              \     1--------1  |         L  |
                    |  |     |  |            v
              /     |  0-----|--1           axis2
     y-values |     | /      | /
              |     |/       |/
              \     0--------1
    
                    |        |
                    |        |
                    v        v
                  (1, 0)   (1, 1)
    

    np.rollaxis 为您提供了一种方法来做到这一点。 np.rollaxis(a, 0, 3) 在上面的例子中说“取第 0 个(或最外层)轴并使其成为最后一个(或最内层)轴。(注意:只有轴 0,这里实际上存在 1 和 2。所以说“将第 0 轴发送到第 3 位”是告诉 python 将第 0 轴放在最后一个轴之后的一种方式。您可能还想阅读this

    >>> a = np.rollaxis(a, 0, 3)
    >>> a
    array([[[0, 0],
            [0, 1]],
    
           [[1, 0],
            [1, 1]]])
    

    这开始看起来像你想要的,除了有一个额外的数组维度。我们想要合并维度 0 和 1 以获得一个网格点数组。但是现在扁平化数组以您期望的方式读取,您可以安全地对其进行整形以获得所需的结果。

    >>> a = a.reshape((4, 2))
    >>> a
    array([[0, 0],
           [0, 1],
           [1, 0],
           [1, 1]])
    

    3D 版本做同样的事情,除了,我无法为它制作一个数字,因为它是 4D 的。

    【讨论】:

    • 这非常有效(100000 次运行大约需要 4 秒),但它相当混乱,您能解释一下它是如何工作的吗? (或者请指向我可以理解的一些文档?)
    • 为了您的利益,我添加了一个解释,但在我的计算机上,itertools.product 实际上运行速度大约快了 6 倍。我的方法中的大部分时间都被mgrid 本身所消耗,因此您甚至无法通过避免rollaxisreshape 来摆脱它。出于好奇,您使用的是什么版本的 Python 和 numpy?
    • 我刚刚意识到实现rollaxis+reshape 效果的另一种方法是使用zip(a[0].flatten(), a[1].flatten(), a[2].flatten())
    • 哇!感谢您的解释!我正在运行 python 2.7.6 和 numpy 1.8.1,我再次检查,结果在我的机器上是相似的。 itertools 需要更长的时间!
    【解决方案2】:

    你可以使用itertools.product:

    In [16]: from itertools import product
    
    In [17]: values = list(product(range(4), range(4), range(11)))
    
    In [18]: values[:5]
    Out[18]: [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4)]
    
    In [19]: values[-5:]
    Out[19]: [(3, 3, 6), (3, 3, 7), (3, 3, 8), (3, 3, 9), (3, 3, 10)]
    

    给定范围数组,您可以执行以下操作。 (我使用了几个非零的低值来演示一般情况 - 并减少输出的大小。:)

    In [41]: ranges = np.array([[0, 3], [1, 3], [8, 10]])
    
    In [42]: list(product(*(range(lo, hi+1) for lo, hi in ranges)))
    Out[42]: 
    [(0, 1, 8),
     (0, 1, 9),
     (0, 1, 10),
     (0, 2, 8),
     (0, 2, 9),
     (0, 2, 10),
     (0, 3, 8),
     (0, 3, 9),
     (0, 3, 10),
     (1, 1, 8),
     (1, 1, 9),
     (1, 1, 10),
     (1, 2, 8),
     (1, 2, 9),
     (1, 2, 10),
     (1, 3, 8),
     (1, 3, 9),
     (1, 3, 10),
     (2, 1, 8),
     (2, 1, 9),
     (2, 1, 10),
     (2, 2, 8),
     (2, 2, 9),
     (2, 2, 10),
     (2, 3, 8),
     (2, 3, 9),
     (2, 3, 10),
     (3, 1, 8),
     (3, 1, 9),
     (3, 1, 10),
     (3, 2, 8),
     (3, 2, 9),
     (3, 2, 10),
     (3, 3, 8),
     (3, 3, 9),
     (3, 3, 10)]
    

    如果所有范围的低值为0,则可以使用np.ndindex

    In [52]: values = list(np.ndindex(4, 4, 11))
    
    In [53]: values[:5]
    Out[53]: [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4)]
    
    In [54]: values[-5:]
    Out[34]: [(3, 3, 6), (3, 3, 7), (3, 3, 8), (3, 3, 9), (3, 3, 10)]
    

    【讨论】:

    • 不,所有的低值都不是 0,因此我认为我不能使用np.ndindex。另一种方法对我有用。一旦我有了元组列表,我就可以将它转换为一个 numpy 数组。谢谢!!
    • 我刚刚注意到,运行该方法 100000 次,我的方法在 9 秒内给出结果,而使用 itertools 需要 44 秒。这种方法编码起来要简单得多,但我在看效率,因为我必须做几百万次。
    猜你喜欢
    • 1970-01-01
    • 2023-01-23
    • 1970-01-01
    • 1970-01-01
    • 2013-01-04
    • 1970-01-01
    • 2021-09-09
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多