【问题标题】:Efficiently multiply elements of each row together有效地将每一行的元素相乘
【发布时间】:2018-03-15 01:12:44
【问题描述】:

给定大小为(n, 3)n 的ndarray 大约为1000,如何快速将每行的所有元素相乘?下面的(不优雅的)第二种解决方案运行时间约为 0.3 毫秒,是否可以改进?

# dummy data
n = 999
a = np.random.uniform(low=0, high=10, size=n).reshape(n/3,3)

# two solutions
def prod1(array):
    return [np.prod(row) for row in array]

def prod2(array):
    return [row[0]*row[1]*row[2] for row in array]

# benchmark
start = time.time()
prod1(a)
print time.time() - start
# 0.0015

start = time.time()
prod2(a)
print time.time() - start
# 0.0003

【问题讨论】:

    标签: python arrays performance numpy multidimensional-array


    【解决方案1】:

    np.prod 接受轴参数:

    np.prod(a, axis=1)
    

    使用axis=1,会为每一行计算列乘积。

    完整性检查

    assert np.array_equal(np.prod(a, axis=1), prod1(a))
    

    性能

    17.6 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    

    (1000 倍加速)

    【讨论】:

    • (抱歉不接受,但另一个答案提供了更好的改进)
    • @anderstood 肯定 np,numpy 总是会输给 njit 解决方案。我会用 cython buuuut 进行报复,我认为它可能仍然不足。
    【解决方案2】:

    进一步提高性能

    首先是一般的经验法则。您正在使用数值数组,因此请使用数组而不是列表。列表可能看起来有点像一个通用数组,但在后端却完全不同,并且绝对不适用于大多数数值计算。

    如果您使用 Numpy-Arrays 编写一个简单的代码,您可以通过简单地对其进行 jitting 来获得性能,如图所示。如果你使用列表,你可以或多或少地重写你的代码。

    import numpy as np
    import numba as nb
    
    @nb.njit(fastmath=True)
    def prod(array):
      assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)
      res=np.empty(array.shape[0],dtype=array.dtype)
      for i in range(array.shape[0]):
        res[i]=array[i,0]*array[i,1]*array[i,2]
    
      return res
    

    使用np.prod(a, axis=1) 不是一个坏主意,但性能并不是很好。对于只有 1000x3 的数组,函数调用开销非常大。在另一个 jited 函数中使用 jited prod 函数时,可以完全避免这种情况。

    基准测试

    # The first call to the jitted function takes about 200ms compilation overhead. 
    #If you use @nb.njit(fastmath=True,cache=True) you can cache the compilation result for every successive call.
    n=999
    prod1   = 795  µs
    prod2   = 187  µs
    np.prod = 7.42 µs
    prod      0.85 µs
    
    n=9990
    prod1   = 7863 µs
    prod2   = 1810 µs
    np.prod = 50.5 µs
    prod      2.96 µs
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2023-03-19
      相关资源
      最近更新 更多