根据@DSM 的评论,我假设您的C[k] == i 应该是B[k] == i。如果是这种情况,您的循环版本是否看起来像这样?
嵌套循环版本
import numpy as np
N = 5
M = 2
A = np.zeros((M,N))
B = np.random.randint(M, size=N) # contains indices for A
C = np.random.rand(N,N)
for i in range(M):
for j in range(N):
for k in range(N):
if B[k] == i:
A[i,j] += C[j,k]
有不止一种方法可以矢量化这个问题。我将在下面展示我的思考过程,但有更有效的方法可以做到这一点(例如,@DSM 的版本可以识别问题中固有的矩阵乘法)。
为了便于解释,这里是一种方法的演练。
矢量化内循环
让我们从重写内部k 循环开始:
for i in range(M):
for j in range(N):
A[i,j] = C[j, B == i].sum()
将其视为C[j][B == i].sum() 可能更容易。我们只是选择C 的jth 行,仅选择该行中B 等于i 的元素,然后将它们相加。
矢量化最外层循环
接下来让我们分解外部i 循环。现在我们要达到可读性开始受到影响的地步了,不幸的是......
i = np.arange(M)[:,np.newaxis]
mask = (B == i).astype(int)
for j in range(N):
A[:,j] = (C[j] * mask).sum(axis=-1)
这里有几个不同的技巧。在这种情况下,我们将遍历 A 的列。 A 的每一列是C 对应行的子集的总和。 C 行的子集由B 等于行索引i 的位置确定。
为了绕过对i 的迭代,我们通过向i 添加一个新轴来创建一个二维数组B == i。 (如果您对此感到困惑,请查看numpy broadcasting 的文档。)换句话说:
B:
array([1, 1, 1, 1, 0])
i:
array([[0],
[1]])
B == i:
array([[False, False, False, False, True],
[ True, True, True, True, False]], dtype=bool)
我们想要的是获取两个 (M) 过滤后的 C[j] 总和,一个对应于 B == i 中的每一行。这将为我们提供一个二元素向量,对应于A 中的jth 列。
我们不能通过直接索引C 来做到这一点,因为结果不会保持它的形状,因为每一行可能有不同数量的元素。我们将通过将B == i 掩码乘以C 的当前行来解决此问题,得到B == i 是False 的零,而C 的当前行中的值是真的。
为此,我们需要将布尔数组B == i 转换为整数:
mask = (B == i).astype(int):
array([[0, 0, 0, 0, 1],
[1, 1, 1, 1, 0]])
所以当我们将它乘以C的当前行时:
C[j]:
array([ 0.19844887, 0.44858679, 0.35370919, 0.84074259, 0.74513377])
C[j] * mask:
array([[ 0. , 0. , 0. , 0. , 0.74513377],
[ 0.19844887, 0.44858679, 0.35370919, 0.84074259, 0. ]])
然后我们可以对每一行求和得到A的当前列(这将在分配给A[:,j]时广播到一个列):
(C[j] * mask).sum(axis=-1):
array([ 0.74513377, 1.84148744])
全矢量化版本
最后,分解最后一个循环,我们可以应用完全相同的原理为j上的循环添加第三个维度:
i = np.arange(M)[:,np.newaxis,np.newaxis]
mask = (B == i).astype(int)
A = (C * mask).sum(axis=-1)
@DSM 的矢量化版本
正如@DSM 建议的那样,您也可以这样做:
A = (B == np.arange(M)[:,np.newaxis]).dot(C.T)
对于大多数尺寸的M 和N,这是迄今为止最快的解决方案,并且可以说是最优雅的(无论如何,比我的解决方案优雅得多)。
让我们稍微分解一下。
B == np.arange(M)[:,np.newaxis] 完全等同于上面“矢量化最外层循环”部分中的B == i。
关键在于认识到所有j 和k 循环都等效于矩阵乘法。 dot 会将布尔型 B == i 数组转换为与幕后 C 相同的 dtype,因此我们无需担心将其显式转换为不同的类型。
之后,我们只是对 C(一个 5x5 数组)的转置和上面的“掩码”0 和 1 数组执行矩阵乘法,得到一个 2x5 数组。
dot 将利用您已安装的任何优化的 BLAS 库(例如 ATLAS、MKL),因此它非常快。
时间
对于较小的 M 和 N,差异不太明显(循环和 DSM 版本之间的大约 6 倍):
M, N = 2, 5
%timeit loops(B,C,M)
10000 loops, best of 3: 83 us per loop
%timeit k_vectorized(B,C,M)
10000 loops, best of 3: 106 us per loop
%timeit vectorized(B,C,M)
10000 loops, best of 3: 23.7 us per loop
%timeit askewchan(B,C,M)
10000 loops, best of 3: 42.7 us per loop
%timeit einsum(B,C,M)
100000 loops, best of 3: 15.2 us per loop
%timeit dsm(B,C,M)
100000 loops, best of 3: 13.9 us per loop
但是,一旦 M 和 N 开始增长,差异就会变得非常显着(~600x)(注意单位!):
M, N = 50, 20
%timeit loops(B,C,M)
10 loops, best of 3: 50.3 ms per loop
%timeit k_vectorized(B,C,M)
100 loops, best of 3: 10.5 ms per loop
%timeit ik_vectorized(B,C,M)
1000 loops, best of 3: 963 us per loop
%timeit vectorized(B,C,M)
1000 loops, best of 3: 247 us per loop
%timeit askewchan(B,C,M)
1000 loops, best of 3: 493 us per loop
%timeit einsum(B,C,M)
10000 loops, best of 3: 134 us per loop
%timeit dsm(B,C,M)
10000 loops, best of 3: 80.2 us per loop