【问题标题】:Java Iterative Merge Sort RuntimeJava 迭代合并排序运行时
【发布时间】:2018-04-19 06:54:30
【问题描述】:

我目前正在为学校开展一个项目,该项目要求我为不同的排序算法编写代码。最困难的部分是在给定长度为 2^N 的输入数组的情况下编写合并排序的迭代版本。我使用了一个名为 merge 的必需辅助方法来帮助迭代合并。

我的结构如下。给定一个 2^N 的数组(让我们使用 16 的数组大小来解释我的方法),我遍历数组查看每个 2 个整数,并使用 merge() 交换一个大于另一个的整数。这个过程将在长度为 16 的数组中发生 8 次。然后,我将遍历数组,查看每个 4 个整数 4 次。我会使用我的合并方法来合并每组 4 个中的两个有序对。然后,我会查看一个由 8 个整数组成的块……依此类推。我的代码贴在这里:

public static void MergeSortNonRec(long[] a) {
    //======================
    //FILL IN YOUR CODE HERE
    //======================    
    /*
    System.out.print("Our array is: ");
    printArray(a);
    System.out.println('\n');
    */
    int alength = a.length;
    int counter = 2;
    //the counter will iterate through levels 2n - 2 4 8 16 32 etc.
    int pointtracker = 0;
    //the point tracker will keep track of the position in the array
    while (counter <= alength) {
        long [] aux = new long [alength];
        int low = pointtracker;
        int high = pointtracker + counter - 1;
        int mid = (low + high)/2;

        merge(a, aux, low, mid, high);

        if (high < alength - 1) {
            pointtracker += counter; 
            //move to the next block
        }
        else {
            //if our high point is at the end of the array
            counter *= 2;
            pointtracker = 0;
            //start over at a[0], with a doubled counter
        }
    }
    /*
    System.out.print("Final array is: ");
    printArray(a);
    System.out.println('\n');
    */
}//MergeSortNonRec()

我的合并方法如下:

    private static void merge(long[] a, long[] aux, int lo, int mid, int hi) {

    // copy to aux[]
    for (int k = lo; k <= hi; k++) {
        aux[k] = a[k]; 
    }

    // merge back to a[]
    int i = lo, j = mid+1;
    for (int k = lo; k <= hi; k++) {
        if      (i > mid)           a[k] = aux[j++];
        else if (j > hi)            a[k] = aux[i++];
        else if (aux[j] < aux[i])   a[k] = aux[j++];
        else                        a[k] = aux[i++];
    }
}

递归解决方案更加优雅:

    private static void sort(long[] a, long[] aux, int lo, int hi) {
    if (hi <= lo) return;
    int mid = lo + (hi - lo) / 2;
    sort(a, aux, lo, mid);
    sort(a, aux, mid + 1, hi);
    merge(a, aux, lo, mid, hi);
}

public static void MergeSort(long[] a) {
    long[] aux = new long[a.length];
    sort(a, aux, 0, a.length-1);
}

我的问题是运行时。我的教授说合并排序的迭代版本,因为我们只输入长度为 2^N 的数组,应该比非迭代版本运行得更快。但是,我的迭代版本运行速度比大型集合中的递归版本慢。这是我的时间输出示例:

![runtime]: https://imgur.com/a/bzVuw "排序算法"

我可以做些什么来减少迭代归并排序的时间?

编辑:我想通了。我将我的 aux 实例移到了 while 循环之外,这以指数方式减少了时间。谢谢大家!

【问题讨论】:

    标签: java arrays sorting merge mergesort


    【解决方案1】:

    我可以做些什么来减少迭代归并排序的时间?

    Wiki 有一个迭代(自下而上)合并排序的简化示例:

    https://en.wikipedia.org/wiki/Merge_sort#Bottom-up_implementation

    为了减少时间,只对 aux[] 数组进行一次分配,不要在每次合并时复制数据,而是在每次合并时交换对数组的引用。

            long [] t = a;      // swap references
            a = aux;
            aux = t;
    

    如果数组的大小是 2 的奇数次幂,则需要复制一次数组或就地交换,而不是执行第一次合并。

    迭代归并排序应该比递归归并排序运行得更快

    假设两者的版本都经过合理优化,迭代归并排序通常会更快,但相对差异会随着数组大小的增加而减小,因为大部分时间将花费在 merge() 函数中,这对于迭代和递归合并排序。

    需要权衡取舍。递归版本将向堆栈推送和弹出长度 - 2 或 2*length - 2 对索引,而迭代则动态生成索引(可以保存在寄存器中)。似乎在更深层次的递归中,递归版本对缓存更友好,因为它在数组的一部分上运行,而迭代版本在每次遍历时总是在整个数组上运行,但我从未见过在这种情况下,使用递归合并排序会带来更好的整体性能。 PC 上的大多数缓存都是 4 路或更多路组关联的,因此两行用于输入,一根用于合并过程中的输出。在我的测试中,多线程迭代合并排序比单线程迭代合并排序快得多,因此我测试过的系统上的合并排序不受内存带宽限制。

    下面是迭代(自下而上)合并排序和测试程序的一些优化示例:

    package jsortbu;
    import java.util.Random;
    
    public class jsortbu {
        static void MergeSort(int[] a)          // entry function
        {
            if(a.length < 2)                    // if size < 2 return
                return;
            int[] b = new int[a.length];
            BottomUpMergeSort(a, b);
        }
    
        static void BottomUpMergeSort(int[] a, int[] b)
        {
        int n = a.length;
        int s = 1;                              // run size 
            if(1 == (GetPassCount(n)&1)){       // if odd number of passes
                for(s = 1; s < n; s += 2)       // swap in place for 1st pass
                    if(a[s] < a[s-1]){
                        int t = a[s];
                        a[s] = a[s-1];
                        a[s-1] = t;
                    }
                s = 2;
            }
            while(s < n){                       // while not done
                int ee = 0;                     // reset end index
                while(ee < n){                  // merge pairs of runs
                    int ll = ee;                // ll = start of left  run
                    int rr = ll+s;              // rr = start of right run
                    if(rr >= n){                // if only left run
                        do                      //   copy it
                            b[ll] = a[ll];
                        while(++ll < n);
                        break;                  //   end of pass
                    }
                    ee = rr+s;                  // ee = end of right run
                    if(ee > n)
                        ee = n;
                    Merge(a, b, ll, rr, ee);
                }
                {                               // swap references
                    int[] t = a;
                    a = b;
                    b = t;
                }
                s <<= 1;                        // double the run size
            }
        }
    
        static void Merge(int[] a, int[] b, int ll, int rr, int ee) {
            int o = ll;                         // b[]       index
            int l = ll;                         // a[] left  index
            int r = rr;                         // a[] right index
            while(true){                        // merge data
                if(a[l] <= a[r]){               // if a[l] <= a[r]
                    b[o++] = a[l++];            //   copy a[l]
                    if(l < rr)                  //   if not end of left run
                        continue;               //     continue (back to while)
                    do                          //   else copy rest of right run
                        b[o++] = a[r++];
                    while(r < ee);
                    break;                      //     and return
                } else {                        // else a[l] > a[r]
                    b[o++] = a[r++];            //   copy a[r]
                    if(r < ee)                  //   if not end of right run
                        continue;               //     continue (back to while)
                    do                          //   else copy rest of left run
                        b[o++] = a[l++];
                    while(l < rr);
                    break;                      //     and return
                }
            }
        }
    
        static int GetPassCount(int n)          // return # passes
        {
            int i = 0;
            for(int s = 1; s < n; s <<= 1)
                i += 1;
            return(i);
        }
    
        public static void main(String[] args) {
            int[] a = new int[10000000];
            Random r = new Random();
            for(int i = 0; i < a.length; i++)
                a[i] = r.nextInt();
            long bgn, end;
            bgn = System.currentTimeMillis();
            MergeSort(a);
            end = System.currentTimeMillis();
            for(int i = 1; i < a.length; i++){
                if(a[i-1] > a[i]){
                    System.out.println("failed");
                    break;
                }
            }
            System.out.println("milliseconds " + (end-bgn));
        }
    }
    

    【讨论】:

    • 我能够找出我的问题 - 只需将 long [] aux 移到 while 循环之外,因为它基本上是一个空数组,只是由 merge() 方法更改,但确实不需要每次传递都实例化。
    • @TVK - 我用合并排序的一些优化版本更新了我的答案。在我的系统(Intel 3770K 3.5ghz、Win 7 Pro 64 bit、NetBeans 8.1)上对 1000 万个整数进行排序大约需要 1.1 秒。
    猜你喜欢
    • 1970-01-01
    • 2017-03-01
    • 1970-01-01
    • 2014-12-19
    • 2015-12-10
    • 2023-03-15
    • 2018-03-08
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多