正如我在a comment 中提到的,矢量化不再是一个巨大的优势。因此,有一些矢量化方法可以减慢代码速度而不是加快速度。您必须始终为您的解决方案计时。矢量化通常涉及创建大型临时数组或复制大量数据,这些在循环代码中被避免。如果这样的解决方案要更快,这取决于架构、输入的大小以及许多其他因素。
尽管如此,在这种情况下,矢量化方法似乎可以产生很大的加速。
关于原始代码,首先要注意的是X(:, i, 1) .* X(:, j, 2) 在内部循环中被重新计算,尽管它在那里是一个常量值。重写内部循环,这样可以节省时间:
Y = X(:, i, 1) .* X(:, j, 2);
for k = 1:p
T(i, j, k) = sum(Y .* X(:, k, 3));
end
现在我们注意到内循环是一个点积,可以写成:
Y = X(:, i, 1) .* X(:, j, 2);
T(i, j, :) = Y.' * X(:, :, 3);
Y 上的.' 转置不会复制数据,因为Y 是一个向量。接下来,我们注意到X(:, :, 3) 被重复索引。让我们把它移出外循环。现在我剩下以下代码:
T = zeros(p, p, p);
X1 = X(:, :, 1);
X2 = X(:, :, 2);
X3 = X(:, :, 3);
for i = 1:p
for j = 1:p
Y = X1(:, i) .* X2(:, j);
T(i, j, :) = Y.' * X3;
end
end
删除j 上的循环很可能同样容易,这将在i 上留下一个循环。但这就是我停下来的地方。
这是我看到的时间安排(R2017a,3 岁的 4 核 iMac)。对于n=10, p=20:
original: 0.0206
moving Y out the inner loop: 0.0100
removing inner loop: 0.0016
moving indexing out of loops: 7.6294e-04
Luis' answer: 1.9196e-04
对于更大的数组,n=50, p=100:
original: 2.9107
moving Y out the inner loop: 1.3488
removing inner loop: 0.0910
moving indexing out of loops: 0.0361
Luis' answer: 0.1417
“路易斯的回答”是this one。对于小型阵列,它是迄今为止最快的,但对于较大的阵列,它显示了排列的成本。将第一个乘积的计算移出内部循环可以节省一半以上的计算成本。但是删除内部循环会大大降低成本(我没有预料到,我认为单矩阵产品可以比许多小的元素产品更好地使用并行性)。然后,我们通过减少循环内的索引操作量来进一步减少时间。
这是计时码:
function so()
n = 10; p = 20;
%n = 50; p = 100;
X = randn(n,p,3);
T1 = method1(X);
T2 = method2(X);
T3 = method3(X);
T4 = method4(X);
T5 = method5(X);
assert(max(abs(T1(:)-T2(:)))<1e-13)
assert(max(abs(T1(:)-T3(:)))<1e-13)
assert(max(abs(T1(:)-T4(:)))<1e-13)
assert(max(abs(T1(:)-T5(:)))<1e-13)
timeit(@()method1(X))
timeit(@()method2(X))
timeit(@()method3(X))
timeit(@()method4(X))
timeit(@()method5(X))
function T = method1(X)
p = size(X,2);
T = zeros(p, p, p);
for i = 1:p
for j = 1:p
for k = 1:p
T(i, j, k) = sum(X(:, i, 1) .* X(:, j, 2) .* X(:, k, 3));
end
end
end
function T = method2(X)
p = size(X,2);
T = zeros(p, p, p);
for i = 1:p
for j = 1:p
Y = X(:, i, 1) .* X(:, j, 2);
for k = 1:p
T(i, j, k) = sum(Y .* X(:, k, 3));
end
end
end
function T = method3(X)
p = size(X,2);
T = zeros(p, p, p);
for i = 1:p
for j = 1:p
Y = X(:, i, 1) .* X(:, j, 2);
T(i, j, :) = Y.' * X(:, :, 3);
end
end
function T = method4(X)
p = size(X,2);
T = zeros(p, p, p);
X1 = X(:, :, 1);
X2 = X(:, :, 2);
X3 = X(:, :, 3);
for i = 1:p
for j = 1:p
Y = X1(:, i) .* X2(:, j);
T(i, j, :) = Y.' * X3;
end
end
function T = method5(X)
T = sum(permute(X(:,:,1), [2 4 5 3 1]) .* permute(X(:,:,2), [4 2 5 3 1]) .* permute(X(:,:,3), [4 5 2 3 1]), 5);