【问题标题】:Efficient way to calculate 3D matrix multiplication using numpy使用 numpy 计算 3D 矩阵乘法的有效方法
【发布时间】:2021-07-17 18:08:43
【问题描述】:

如何使用 numpy 有效地编写和计算这个乘法:

 for k in range(K):
    for i in range(SIZE):
       for j in range(SIZE):
          for i_b in range(B_SIZE):
             for j_b in range(B_SIZE):
                for k_b in range(k+1):
                   data[k, i * w + i_b, j * h + j_b] += arr1[k_b, i_b, j_b] * arr2[k_b, i, j]

例如:

SIZE, B_SIZE = 32, 8
arr1.shape -> (8, 8, 8)
arr2.shape -> (8, 32, 32)
data.shape -> (K, 256, 256)

谢谢。

【问题讨论】:

  • k8是什么关系?
  • 似乎einsummatmul 可以做到;至少有+=*。但是 6 个迭代器之间的映射很复杂(并且需要太多的工作:()。

标签: numpy matrix optimization matrix-multiplication numpy-ndarray


【解决方案1】:

您可以将 Numba 用于这种重要的情况,并重做循环以有效地使用 CPU 缓存。这是一个例子:

import numba as nb

@nb.njit
def compute(data, arr1, arr2):
    for k in range(K):
        for k_b in range(k+1):
            for i in range(SIZE):
                for j in range(SIZE):
                    tmp = arr2[k_b, i, j]
                    for i_b in range(B_SIZE):
                        for j_b in range(B_SIZE):
                            data[k, i * w + i_b, j * h + j_b] += arr1[k_b, i_b, j_b] * tmp

如果您执行此操作一次,则可以通过提供数组的类型预编译 Numba 代码。如果K 很大,那么您可以使用@nb.njit(parallel=True)并行化代码并使用for k in nb.prange(K) 而不是for k in range(K)。这应该是几个数量级。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2021-04-11
    • 2022-09-22
    • 2017-02-10
    • 2016-05-11
    • 2021-03-05
    • 1970-01-01
    • 2018-01-15
    • 2018-11-30
    相关资源
    最近更新 更多