【问题标题】:Tensor multiplication w/o looping in MatlabMatlab中没有循环的张量乘法
【发布时间】:2018-08-19 09:01:23
【问题描述】:

我有一个 3d 数组 A,例如A=rand(N,N,K)。

我需要一个数组 B s.t.

B(n,m) = norm(A(:,:,n)*A(:,:,m)' - A(:,:,m)*A(:,:,n)','fro')^2 for all indices n,m in 1:K.

这是循环代码:

B = zeros(K,K);    
for n=1:K
       for m=1:K
           B(n,m) = norm(A(:,:,n)*A(:,:,m)' - A(:,:,m)*A(:,:,n)','fro')^2;
       end
end

我不想循环 1:K。

我可以创建一个大小为 NK x NK s.t 的数组 An_x_mt。

An_x_mt equals A(:,:,n)*A(:,:,m)' for all n,m in 1:K by
An_x_mt = Ar*Ac_t; 

Ac_t=reshape(permute(A,[2 1 3]),size(A,1),[]); 
Ar=Ac_t';

如何创建大小为 NK x NK s.t 的数组 Am_x_nt。

Am_x_nt equals A(:,:,m)*A(:,:,n)' for all n,m in 1:K

这样我就可以了

B = An_x_mt  - Am_x_nt
B = reshape(B,N,N,[]);
B = reshape(squeeze(sum(sum(B.^2,1),2)),K,K);

谢谢

【问题讨论】:

  • "我不想循环 1:K。"为什么不?您确定这是您代码中的瓶颈吗?你确定没有循环会更快吗?你对reshapepermute 所做的任何事情都比循环更难阅读,因此维护成本更高。
  • 是的,这是一个主要瓶颈。数组乘法更快
  • 这是 Frobenius 范数 mathworld.wolfram.com/FrobeniusNorm.html 矩阵绝对值平方和
  • 问题是如何在不循环的情况下计算 Am_x_nt(n,m) = A(:,:,m)*A(:,:,n)'。
  • mmxmtimesx

标签: matlab loops tensor


【解决方案1】:

对于那些不能/不会使用 mmx 并希望坚持纯 Matlab 代码的人,您可以这样做。 mat2cell 和 cell2mat 函数是你的朋友:

[N,~,nmat]=size(A);
Atc = reshape(permute(A,[2 1 3]),N,[]); % A', N x N*nmat
Ar = Atc'; % A, N*nmat x N
Anmt_2d = Ar*Atc; % An*Am'
Anmt_2d_cell = mat2cell(Anmt_2d,N*ones(nmat,1),N*ones(nmat,1));
Amnt_2d_cell = Anmt_2d_cell'; % ONLY products transposed, NOT their factors
Amnt_2d = cell2mat(Amnt_2d_cell); % Am*An'
Anm = Anmt_2d - Amnt_2d;
Anm = Anm.^2;
Anm_cell = mat2cell(Anm,N*ones(nmat,1),N*ones(nmat,1));
d = cellfun(@(c) sum(c(:)), Anm_cell); % squared Frobenius norm of each product; nmat x nmat

或者,在计算 Anmt_2d_cell 和 Amnt_2d_cell 之后,您可以将它们转换为 3d,其中第 3 维编码 (n,m) 和 (m,n) 索引,然后在 3d 中执行其余计算。您将需要来自这里的 permn() 实用程序https://www.mathworks.com/matlabcentral/fileexchange/7147-permn-v-n-k

Anmt_3d = cat(3,Anmt_2d_cell);
Amnt_3d = cat(3,Amnt_2d_cell);
Anm_3d = Anmt_3d - Amnt_3d;
Anm_3d = Anm_3d.^2;
Anm = squeeze(sum(sum(Anm_3d,1),2));
d = zeros(nmat,nmat);
nm=permn(1:nmat, 2); % all permutations (n,m) with repeat, by-row order
d(sub2ind([nmat,nmat],nm(:,1),nm(:,2))) = Anm;

由于某种原因,第二个选项(3D 数组)要快两倍。

希望这会有所帮助。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-07-01
    • 1970-01-01
    • 2015-10-08
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多