【问题标题】:Vectorize find function in a for loop在 for 循环中矢量化查找函数
【发布时间】:2013-07-03 17:43:32
【问题描述】:

我有以下代码输出小于或等于array2 的每个元素的array1 的值。这两个数组的长度不同。这个 for 循环非常慢,因为数组很大(~500,000 元素)。仅供参考,两个数组始终按升序排列。

任何帮助使它成为矢量操作并加快它的速度将不胜感激。

我正在考虑使用“最近”选项的interp1() 的某种多步骤过程。然后找到对应的outArray 大于array2 的位置,然后以某种方式固定点……但我认为必须有更好的方法。

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];
outArray = nan(size(array2));
for a =1:numel(array2)
    outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end

返回:

outArray =    
     5     5    15    24

【问题讨论】:

    标签: algorithm matlab vectorization


    【解决方案1】:

    这是一种可能的矢量化:

    [~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
    outArray = array1(idx);
    

    编辑:

    在最近的版本中,由于 JIT 编译,MATLAB 在执行良好的老式非矢量化循环方面已经相当出色。

    下面是一些类似于您的代码,它利用两个数组已排序这一事实(因此,如果 pos(a) = find(array1&lt;=array2(a), 1, 'last') 那么我们保证在下一次迭代中计算的 pos(a+1) 将不小于前一个 @987654325 @)

    pos = 1;
    idx = zeros(size(array2));
    for a=1:numel(array2)
        while pos <= numel(array1) && array1(pos) <= array2(a)
            pos = pos + 1;
        end
        idx(a) = pos-1;
    end
    %idx(idx==0) = [];      %# in case min(array2) < min(array1)
    outArray = array1(idx);
    

    注意:注释行处理array2 的最小值小于array1 的最小值(即find(array1&lt;=array2(a)) 为空时)的情况

    我对目前发布的所有方法进行了比较,这确实是最快的。长度为 N=5000 的向量的时序(使用 TIMEIT 函数执行)为:

    0.097398     # your code
    0.39127      # my first vectorized code
    0.00043361   # my new code above
    0.0016276    # Mohsen Nosratinia's code
    

    这里是 N=500000 的时间:

    (? too-long) # your code
    (out-of-mem) # my first vectorized code
    0.051197     # my new code above
    0.25206      # Mohsen Nosratinia's code
    

    .. 从您报告的最初 10 分钟缩短到 0.05 秒,这是一个相当不错的改进!

    如果你想重现结果,这里是测试代码:

    function [t,v] = test_array_find()
        %array2 = [5 6 18 25];
        %array1 = [1 5 9 15 22 24 31];
        N = 5000;
        array1 = sort(randi([100 1e6], [1 N]));
        array2 = sort(randi([min(array1) 1e6], [1 N]));
    
        f = {...
            @() func1(array1,array2);   %# Aero Engy
            @() func2(array1,array2);   %# Amro
            @() func3(array1,array2);   %# Amro
            @() func4(array1,array2);   %# Mohsen Nosratinia
        };
    
        t = cellfun(@timeit, f);
        v = cellfun(@feval, f, 'UniformOutput',false);
        assert( isequal(v{:}) )
    end
    
    function outArray = func1(array1,array2)
        %idx = arrayfun(@(a) find(array1<=a, 1, 'last'), array2);
        idx = zeros(size(array2));
        for a=1:numel(array2)
            idx(a) = find(array1 <= array2(a), 1, 'last');
        end
        outArray = array1(idx);
    end
    
    function outArray = func2(array1,array2)
        [~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
        outArray = array1(idx);
    end
    
    function outArray = func3(array1,array2)
        pos = 1;
        lastPos = numel(array1);
        idx = zeros(size(array2));
        for a=1:numel(array2)
            while pos <= lastPos && array1(pos) <= array2(a)
                pos = pos + 1;
            end
            idx(a) = pos-1;
        end
        %idx(idx==0) = [];      %# in case min(array2) < min(array1)
        outArray = array1(idx);
    end
    
    function outArray = func4(array1,array2)
        [~,I] = sort([array1 array2]);
        a1size = numel(array1);
        J = find(I>a1size);
        outArray = nan(size(array2));
        for k=1:numel(J),
            if  I(J(k)-1)<=a1size,
                outArray(k) = array1(I(J(k)-1));
            else
                outArray(k) = outArray(k-1);
            end
        end
    end
    

    【讨论】:

    • 感谢您的帮助,但由于数组很大,这会导致内存不足错误。仅供参考,在具有 16GB 内存的 Linux 服务器上运行
    • 我认为它会发生。矢量化通常以增加内存使用为代价(上面的代码将构建一个大小为numel(array1)-by-numel(array2) 的中间矩阵,在您的情况下约为 18GB!)......您当前的代码如何在完整数据上执行,你能提供一些时间安排?
    • (1x223700)的array1和(1x223100)的array2大约需要10分钟。使用 interp1(array1,array1,array2,'nearest') 大约需要 0.4 秒。但是,其中大约 50% 是错误的,因为它选择了一些大于 array2 的点。我需要类似 interp1(...'lessthan') 的东西,但它不存在:)
    • @Amro 有趣的是,我在 10 分钟前也发布了类似的解决方案。看来今晚我不是唯一一个不眠不休的程序员 :-)
    • @MohsenNosratinia:我知道,我刚刚刷新了我的页面 :) 抱歉,我的基准测试包含你的旧代码!
    【解决方案2】:

    它变慢的一个原因是您将array1 中的所有元素与array2 中的所有元素进行比较,因此如果它们分别包含MN 元素,则复杂度为O(M*N)。但是,由于数组已经排序,因此有一个线性时间 O(M+N) 解决方案

    array2 = [5 6 18 25];
    array1 = [1 5 9 15 22 24 31];
    
    outArray = nan(size(array2));
    k1 = 1;
    n1 = numel(array1);
    n2 = numel(array2);
    
    ks = 1;
    while ks <= n2 && array2(ks) < array1(1)
        ks = ks + 1;
    end
    
    for k2=ks:n2
        while k1 < n1 && array2(k2) >= array1(k1+1) 
            k1 = k1+1;
        end
        outArray(k2) = array1(k1);
    end
    

    这是一个测试用例,用于测量每个方法运行两个长度为 500,000 的数组所需的时间。

    array2 = 1:500000;
    array1 = array2-1;
    
    tic
    outArray1 = nan(size(array2));
    k1 = 1;
    n1 = numel(array1);
    n2 = numel(array2);
    
    ks = 1;
    while ks <= n2 && array2(ks) < array1(1)
        ks = ks + 1;
    end
    
    for k2=ks:n2
        while k1 < n1 && array2(k2) >= array1(k1+1) 
            k1 = k1+1;
        end
        outArray1(k2) = array1(k1);
    end
    toc    
    
    tic
    outArray2 = nan(size(array2));
    for a =1:numel(array2)
        outArray2(a) = array1(find(array1 <= array2(a),1,'last'));
    end
    toc
    

    结果是

    Elapsed time is 0.067637 seconds.
    Elapsed time is 418.458722 seconds.
    

    【讨论】:

    • 哈,看来我们的想法是一样的。我发布了类似的东西:)
    • 现在您参考了我删除的旧解决方案。我要取消删除它并添加一条注释。
    【解决方案3】:

    注意: 这是我最初的解决方案,也是 Amro 答案中的基准。但是,它比我在其他答案中提供的线性时间解决方案要慢。

    它变慢的一个原因是您将array1 中的所有元素与array2 中的所有元素进行比较,因此如果它们包含MN 元素,则复杂度为O(M*N)。但是,您可以将它们连接起来并将它们排序在一起,并获得更快的复杂算法(M+N)*log2(M+N)。这是一种方法:

    array2 = [5 6 18 25];
    array1 = [1 5 9 15 22 24 31];
    
    [~,I] = sort([array1 array2]);
    a1size = numel(array1);
    J = find(I>a1size);
    outArray = nan(size(array2));
    for k=1:numel(J),
        if  I(J(k)-1)<=a1size,
            outArray(k) = array1(I(J(k)-1));
        else
            outArray(k) = outArray(k-1);
        end
    end
    
    disp(outArray)
    
    % Test using original code
    outArray = nan(size(array2));
    for a =1:numel(array2)
        outArray(a) = array1(find(array1 <= array2(a),1,'last'));
    end
    disp(outArray)
    

    串联的数组将是

    >> [array1 array2]
    ans =
         1     5     9    15    22    24    31     5     6    18    25
    

    >> [B,I] = sort([array1 array2])
    B =
         1     5     5     6     9    15    18    22    24    25    31
    I =
         1     2     8     9     3     4    10     5     6    11     7
    

    它表明在排序数组B 中,第一个5 来自连接数组中的第二个位置,第二个5 来自八个位置,依此类推。因此,要找到array1 中小于array2 中给定元素的最大元素,我们只需要遍历I 中大于array1 大小的所有索引(因此属于array2 ) 并返回并找到属于array1 的最近索引。 J 包含这些元素在向量I 中的位置:

    >> J = find(I>a1size)
    J =
         3     4     7    10
    

    现在 for 循环遍历这些索引并检查 I 中的索引是否正好位于从 J 引用的每个索引之前的索引是否属于 array1。如果它属于array1,它会从array1 中检索它的值,否则它会复制为先前索引找到的值。

    请注意,如果 array2 包含的元素小于 array1 中的最小元素,则您的代码和此代码都会失败。

    【讨论】:

      猜你喜欢
      • 2020-05-03
      • 2018-10-15
      • 1970-01-01
      • 2015-01-30
      • 2019-02-28
      • 2011-02-09
      • 2015-02-03
      相关资源
      最近更新 更多