【问题标题】:Improving MATLAB Matrix Construction Code : Or, code Vectorization for beginners改进 MATLAB 矩阵构造代码:或者,初学者的代码向量化
【发布时间】:2013-07-06 14:44:41
【问题描述】:

我编写了一个程序来构造一个 3 波段小波变换矩阵的一部分。但是,考虑到矩阵的大小为 3^9 X 3^10,MATLAB 需要一段时间才能完成构造它。因此,我想知道是否有办法改进我正在使用的代码以使其运行得更快。我在运行代码时使用 n=10 。

B=zeros(3^(n-1),3^n);
v=[-0.117377016134830 0.54433105395181 -0.0187057473531300 -0.699119564792890 -0.136082763487960 0.426954037816980 ];

for j=1:3^(n-1)-1 
    for k=1:3^n;
        if k>6+3*(j-1) || k<=3*(j-1)
            B(j,k)=0;
        else 
            B(j,k)=v(k-3*(j-1));
        end                
    end
end
j=3^(n-1);
    for k=1:3^n
        if k<=3
            B(j,k)=v(k+3);
        elseif k<=3^n-3
            B(j,k)=0;
        else 
            B(j,k)=v(k-3*(j-1));
        end
    end

W=B;

【问题讨论】:

标签: matlab matrix vectorization wavelet


【解决方案1】:

如何在不知道如何矢量化的情况下进行矢量化:

首先,我将只讨论矢量化第一个双循环,您可以对第二个循环遵循相同的逻辑。

我试图从头开始展示一个思考过程,所以虽然最终答案只有 2 行长,但值得看看初学者如何尝试获得它。

首先,我建议在简单的情况下“按摩”代码,以感受一下。例如,我使用了n=3v=1:6 并运行了第一个循环,这就是B 的样子:

[N M]=size(B)
N =
     9
M =
    27

imagesc(B); 

所以你可以看到我们得到了一个像矩阵一样的楼梯,这是非常规则的!我们只需要将正确的矩阵索引分配给v 的正确值即可。

有很多方法可以实现这一点,有些方法比其他方法更优雅。最简单的方法之一是使用函数find

pos=[find(B==v(1)),find(B==v(2)),find(B==v(3)),...
     find(B==v(4)),find(B==v(5)),find(B==v(6))]

pos =
     1    10    19    28    37    46
    29    38    47    56    65    74
    57    66    75    84    93   102
    85    94   103   112   121   130
   113   122   131   140   149   158
   141   150   159   168   177   186
   169   178   187   196   205   214
   197   206   215   224   233   242

上面的值是矩阵Blinear indices,其中找到了v 的值。每列代表B 中特定值vlinear index。例如,索引[1 29 57 ...] 都包含值v(1),等等......每一行都包含v,因此,索引[29 38 47 56 65 74] 包含v=[v(1) v(2) ... v(6)]。可以注意到,每一行的索引差是9,或者说,每个索引之间有一个步长N,有6个,正好是向量v的长度(也是@得到的) 987654342@)。对于列,相邻元素之间的差为 28,或者步长为M+1

我们只需要根据这个逻辑在适当的索引中分配v 的值。一种方法是编写每个“行”:

B([1:N:numel(v)*N]+(M+1)*0)=v;
B([1:N:numel(v)*N]+(M+1)*1)=v;
...
B([1:N:numel(v)*N]+(M+1)*(N-2))=v;

但这对于大N-2 来说是不切实际的,所以如果你真的想要的话,你可以在 for 循环中做到这一点:

for kk=0:N-2;
     B([1:N:numel(v)*N]+(M+1)*kk)=v;
end

Matlab 提供了一种更有效的方法来使用 bsxfun 一次获取所有索引(这取代了 for 循环),例如:

ind=bsxfun(@plus,1:N:N*numel(v),[0:(M+1):M*(N-1)+1]')

所以现在我们可以使用indv 分配给矩阵N-1 次。为此,我们需要将ind“扁平化”为行向量:

ind=reshape(ind.',1,[]);

并将v 与自身连接N-1 次(或再制作N-1 个v 的副本):

vec=repmat(v,[1 N-1]);

我们终于得到了答案:

B(ind)=vec;

长话短说,写得紧凑,我们得到了一个 2 行解决方案(已知大小 B[N M]=size(B)):


ind=bsxfun(@plus,1:N:N*numel(v),[0:(M+1):M*(N-1)+1]');
B(reshape(ind.',1,[]))=repmat(v,[1 N-1]);

对于n=9,矢量化代码在我的机器上快了约 850 倍。 (小n 会不那么重要)

由于得到的矩阵大部分是由零组成的,你不需要将这些分配给一个完整的矩阵,而是使用一个稀疏矩阵,这里是完整的代码(非常相似):

N=3^(n-1);
M=3^n;
S=sparse([],[],[],N,M);
ind=bsxfun(@plus,1:N:N*numel(v),[0:(M+1):M*(N-1)+1]');
S(reshape(ind.',1,[]))=repmat(v,[1 N-1]);

对于n=10,我只能运行稀疏矩阵代码(否则内存不足),而在我的机器上大约需要 6 秒。

现在尝试将其应用于第二个循环...

【讨论】:

  • 哇,效果很好!试图弄清楚,但仍然很难弄清楚第二个循环。
【解决方案2】:

虽然您的矩阵具有巨大的维度,但它也非常“稀疏”,这意味着它的大部分元素都是零。为了提高性能,您可以利用MATLAB 的稀疏矩阵支持,确保您只对矩阵的非零部分进行操作。

MATLAB 中的稀疏矩阵可以通过构造稀疏矩阵的coordinate form 来高效地构建。这意味着必须定义三个数组,定义矩阵中每个非零条目的行、列和值。这意味着我们不会通过传统的A(i,j) = x 语法分配值,而是将非零条目附加到我们的稀疏索引结构中:

row(pos+1) = i;
col(pos+1) = j;
val(pos+1) = x;
% pos is the current position within the sparse indexing arrays!

一旦我们的稀疏索引数组中有完整的非零值集,我们就可以使用sparse 命令来构建矩阵。

对于这个问题,我们为每一行添加最多六个非零条目,允许我们提前分配稀疏索引数组。变量pos 跟踪我们在索引数组中的当前位置。

rows = 3^(n-1);
cols = 3^(n+0);

% setup the sparse indexing arrays for non-
% zero elements of matrix B
row = zeros(rows*6,1);
col = zeros(rows*6,1);
val = zeros(rows*6,1);
pos = +0;

我们现在可以通过向稀疏索引数组添加任何非零条目来构建矩阵。由于我们只关心非零条目,因此我们也只遍历矩阵的非零部分。

我把最后一行的逻辑留给你填写!

for j = 1 : 3^(n-1)
    if (j < 3^(n-1))

% add entries for a general row
    for k = max(1,3*(j-1)+1) : min(3^n,3*(j-1)+6)             
        pos = pos+1;
        row(pos) = j;
        col(pos) = k;
        val(pos) = v(k-3*(j-1));                
    end

    else

% add entries for final row - todo!!

    end
end

由于我们没有为每一行添加六个非零,我们可能过度分配了稀疏索引数组,因此我们将它们缩减到实际使用的大小。

% only keep the sparse indexing that we've used
row = row(1:pos);
col = col(1:pos);
val = val(1:pos);

现在可以使用sparse 命令构建最终矩阵。

% build the actual sparse matrix
B = sparse(row,col,val,rows,cols);

通过整理上面的sn-ps就可以运行代码了。对于n = 9,我们得到以下结果(作为比较,我还包括了natan 建议的基于bsxfun 的方法的结果):

Elapsed time is 2.770617 seconds. (original)
Elapsed time is 0.006305 seconds. (sparse indexing)
Elapsed time is 0.261078 seconds. (bsxfun)

n = 10 的原始代码内存不足,但两种稀疏方法仍然可用:

Elapsed time is 0.019846 seconds. (sparse indexing)
Elapsed time is 2.133946 seconds. (bsxfun)

【讨论】:

  • 这创造了奇迹!非常有效的方法来做到这一点。我得到了代码的最后一行,虽然我不知道我是否以最优雅的方式完成了它。不过非常感谢!
【解决方案3】:

您可以使用一种巧妙的方法来创建一个块对角矩阵,如下所示:

>> v=[-0.117377016134830 0.54433105395181 -0.0187057473531300 ...
-0.699119564792890 -0.136082763487960 0.426954037816980];
>>lendiff=长度(v)-3;
>> B=repmat([v zeros(1,3^n-lendiff)],3^(n-1),1);
>> B=重塑(B',3^(n),3^(n-1)+1);
>> B(:,end-1)=B(:,end-1)+B(:,end);
>> B=B(:,1:end-1)';

这里,lendiff 用于创建一行的 3^{n-1} 个副本,其中 v 后跟零,长度为 3^n+3,因此矩阵大小为 [3^{n -1} 3^n+3].

该矩阵被重新整形为大小 [3^n 3^{n-1}+1] 以创建移位。额外的列需要添加到最后,B需要转置。

不过应该​​会快得多。

编辑

看到 Darren 的解决方案并意识到 reshape 也适用于 sparse 矩阵,让我想出了这个 - 没有 for 循环(未编码原始解决方案)。

首先是开始的值:

>> v=[-0.117377016134830  ...
       0.54433105395181   ...
      -0.0187057473531300 ...
      -0.699119564792890  ...
      -0.136082763487960  ...
       0.426954037816980];    
>> rows = 3^(n-1);                  % same number of rows
>> cols = 3^(n)+3;                  % add 3 cols to implement the shifts    

然后使矩阵每行增加 3 列

>> row=(1:rows)'*ones(1,length(v)); % row number where each copy of v is stored'
>> col=ones(rows,1)*(1:length(v));  % place v at the start columns of each row
>> val=ones(rows,1)*v;              % fill in the values of v at those positions
>> B=sparse(row,col,val,rows,cols); % make the matrix B[rows cols+3], but now sparse

然后重塑以实现移位(额外的行,正确的列数)

>> B=reshape(B',3^(n),rows+1);      % reshape into B[3^n rows+1], shifted v per row'
>> B(1:3,end-1)=B(1:3,end);         % the extra column contains last 3 values of v
>> B=B(:,1:end-1)';                 % delete extra column after copying, transpose

对于 n=4,5,6,7,这会导致 s 中的 cpu 时间:

n    original    new version
4    0.033       0.000
5    0.206       0.000
6    1.906       0.000
7    16.311      0.000

由分析器测量。对于原始版本,我无法运行 n>7,但新版本给出了

n    new version
8    0.002
9    0.009
10   0.022
11   0.062
12   0.187
13   0.540
14   1.529
15   4.210

这就是我的 RAM 能走多远 :)

【讨论】:

  • +1 用于@darren 的稀疏索引解决方案。我花了一些时间来查看代码,但这样效率更高!
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2014-05-04
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多