【问题标题】:Un-broadcasting Numpy arrays不广播 Numpy 数组
【发布时间】:2016-11-28 13:43:42
【问题描述】:

在大型代码库中,我使用np.broadcast_to 来广播数组(此处仅使用简单示例):

In [1]: x = np.array([1,2,3])

In [2]: y = np.broadcast_to(x, (2,1,3))

In [3]: y.shape
Out[3]: (2, 1, 3)

在代码的其他地方,我使用了第三方函数,这些函数可以在 Numpy 数组上以矢量化方式运行,但不是 ufunc。这些函数不理解广播,这意味着在像y 这样的数组上调用这样的函数是低效的。 Numpy 的vectorize 之类的解决方案也不好,因为虽然它们理解广播,但它们在数组元素上引入了for 循环,这样效率非常低。

理想情况下,我想做的是拥有一个函数,我们可以调用它,例如unbroadcast,它返回一个具有最小形状的数组,如果需要,可以将其广播回完整大小。例如:

In [4]: z = unbroadcast(y)

In [5]: z.shape
Out[5]: (1, 1, 3)

然后我可以在z 上运行第三方函数,然后将结果广播回y.shape

有没有办法实现依赖于 Numpy 的公共 API 的 unbroadcast?如果没有,是否有任何黑客可以产生预期的结果?

【问题讨论】:

  • 类似y[None,0]?
  • 什么是“最小形状”? unbroadcast 返回的 N 维的最小形状不总是 (1, 1, ..., 1) (甚至是 (1,) )吗?
  • 我的意思是最小的形状,它仍然包含将它广播回整个数组所需的所有数据。所以在上面的例子中,z.shape(1,1,3) 而不是 (1,1,1)

标签: python arrays numpy array-broadcasting


【解决方案1】:

我有一个可能的解决方案,所以会在这里发布(但是如果有人有更好的解决方案,也请随时回复!)。一种解决方案是检查数组的 strides 参数,该参数沿广播维度为 0:

def unbroadcast(array):
    slices = []
    for i in range(array.ndim):
        if array.strides[i] == 0:
            slices.append(slice(0, 1))
        else:
            slices.append(slice(None))
    return array[slices]

这给出了:

In [14]: unbroadcast(y).shape
Out[14]: (1, 1, 3)

【讨论】:

    【解决方案2】:

    这可能相当于您自己的解决方案,只是内置了一点。它在numpy.lib.stride_tricks 中使用as_strided

    import numpy as np
    from numpy.lib.stride_tricks import as_strided
    
    x = np.arange(16).reshape(2,1,8,1)  # shape (2,1,8,1)
    y = np.broadcast_to(x,(2,3,8,5))    # shape (2,3,8,5) broadcast
    
    def unbroadcast(arr):
        #determine unbroadcast shape
        newshape = np.where(np.array(arr.strides) == 0,1,arr.shape) # [2,1,8,1], thanks to @Divakar
        return as_strided(arr,shape=newshape)    # strides are automatically set here
    
    z = unbroadcast(x)
    np.all(z==x)  # is True
    

    请注意,在我的原始答案中,我没有定义函数,结果 z 数组具有 (64,0,8,0)strides,而输入具有 (64,64,8,8)。在当前版本中,返回的z 数组与x 具有相同的步幅,我猜传递和返回数组会强制创建一个副本。无论如何,我们总是可以在as_strided 中手动设置步幅,以在所有情况下获得相同的数组,但在上述设置中似乎没有必要。

    【讨论】:

    • np.where(np.array(y.strides) == 0,1,y.shape) 用于新形状?
    • @Divakar 对,我一直忘记np.where:) 显然这是 numpy 惯用的版本,谢谢。
    • 为了后代,编辑它以使两个关键行成为“未广播”功能可能会很好?
    • @astrofrog 谢谢,我希望任何尝试这样做的人都可以将它变成一个函数;)无论如何,我编辑了我的答案。事实证明,这会产生一个与x 具有相同步幅的数组。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2012-06-26
    • 1970-01-01
    • 1970-01-01
    • 2017-06-18
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多