【问题标题】:Is Arrays.stream(array_name).sum() slower than iterative approach?Arrays.stream(array_name).sum() 是否比迭代方法慢?
【发布时间】:2015-01-13 15:46:09
【问题描述】:

我正在编写一个 leetcode 问题:https://oj.leetcode.com/problems/gas-station/ 使用 Java 8。

当我使用Arrays.stream(integer_array).sum() 计算总和时,我的解决方案获得了 TLE,而使用迭代计算数组中元素的总和时,相同的解决方案被接受。这个问题的最佳时间复杂度是 O(n),我很惊讶在使用 Java 8 的流 API 时得到 TLE。我只在 O(n) 中实现了解决方案。

import java.util.Arrays;

public class GasStation {
    public int canCompleteCircuit(int[] gas, int[] cost) {
        int start = 0, i = 0, runningCost = 0, totalGas = 0, totalCost = 0; 
        totalGas = Arrays.stream(gas).sum();
        totalCost = Arrays.stream(cost).sum();

        // for (int item : gas) totalGas += item;
        // for (int item : cost) totalCost += item;

        if (totalGas < totalCost)
            return -1;

        while (start > i || (start == 0 && i < gas.length)) {
            runningCost += gas[i];
            if (runningCost >= cost[i]) {
                runningCost -= cost[i++];
            } else {
                runningCost -= gas[i];
                if (--start < 0)
                    start = gas.length - 1;
                runningCost += (gas[start] - cost[start]);
            }
        }
        return start;
    }

    public static void main(String[] args) {
        GasStation sol = new GasStation();
        int[] gas = new int[] { 10, 5, 7, 14, 9 };
        int[] cost = new int[] { 8, 5, 14, 3, 1 };
        System.out.println(sol.canCompleteCircuit(gas, cost));

        gas = new int[] { 10 };
        cost = new int[] { 8 };
        System.out.println(sol.canCompleteCircuit(gas, cost));
    }
}

解决方案被接受时, 我评论以下两行:(使用流计算总和)

totalGas = Arrays.stream(gas).sum();
totalCost = Arrays.stream(cost).sum();

并取消注释以下两行(使用迭代计算总和):

//for (int item : gas) totalGas += item;
//for (int item : cost) totalCost += item;

现在解决方案被接受。为什么 Java 8 中的流式 API 对于大输入比原语迭代慢?

【问题讨论】:

  • here 的结果是针对集合(列表)而不是基元计算的。原语没有类似 list.forEach((i) -> doIt(i)); ,相反我们必须使用 Arrays 实用程序。对于集合,Java 流式处理、并行性和缩减比迭代更快。我仍然怀疑原始流的结果如何比正常迭代慢。 leetcode 针对庞大的数据集测试我的解决方案。

标签: java performance algorithm java-8


【解决方案1】:

处理此类问题的第一步是将代码置于受控环境中。这意味着在您控制(并且可以调用)的 JVM 中运行它,并在像 JMH 这样的良好基准工具中运行测试。分析,不要推测。

这是我使用 JMH 建立的一个基准,用于对此进行一些分析:

@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
public class ArraySum {
    static final long SEED = -897234L;

    @Param({"1000000"})
    int sz;

    int[] array;

    @Setup
    public void setup() {
        Random random = new Random(SEED);
        array = new int[sz];
        Arrays.setAll(array, i -> random.nextInt());
    }

    @Benchmark
    public int sumForLoop() {
        int sum = 0;
        for (int a : array)
            sum += a;
        return sum;
    }

    @Benchmark
    public int sumStream() {
        return Arrays.stream(array).sum();
    }
}

基本上,这会创建一个包含一百万个整数的数组并将它们相加两次:一次使用 for 循环,一次使用流。运行基准测试会产生一堆输出(为了简洁和戏剧效果而省略),但总结结果如下:

Benchmark                 (sz)  Mode  Samples     Score  Score error  Units
ArraySum.sumForLoop    1000000  avgt        3   514.473      398.512  us/op
ArraySum.sumStream     1000000  avgt        3  7355.971     3170.697  us/op

哇! Java 8 流的东西就是 SUXX0R!它比 for 循环慢 14 倍,不要使用它!!!1!

嗯,不。首先让我们回顾一下这些结果,然后再仔细观察,看看我们是否能弄清楚发生了什么。

摘要显示了两种基准方法,“sz”参数为一百万。可以更改此参数,但在这种情况下不会产生影响。我也只运行了 3 次基准测试方法,正如您从“样本”列中看到的那样。 (也只有 3 次预热迭代,此处不可见。)每次操作的分数以微秒为单位,显然流代码比 for 循环代码慢得多。但还要注意分数错误:这是不同运行中的可变性量。 JMH 有助于打印出结果的标准偏差(此处未显示),但您可以很容易地看到分数错误是报告分数的很大一部分。这降低了我们对分数的信心。

运行更多迭代应该会有所帮助。更多的预热迭代将使 JIT 在运行基准测试之前做更多的工作并稳定下来,并且运行更多的基准测试迭代将消除我系统上其他地方的瞬态活动产生的任何错误。所以让我们尝试 10 次预热迭代和 10 次基准迭代:

Benchmark                 (sz)  Mode  Samples     Score  Score error  Units
ArraySum.sumForLoop    1000000  avgt       10   504.803       34.010  us/op
ArraySum.sumStream     1000000  avgt       10  7128.942      178.688  us/op

性能总体上快了一点,测量误差也小了很多,所以运行更多的迭代已经达到了预期的效果。但是流代码仍然比 for 循环代码慢得多。怎么回事?

通过查看streams方法的各个时序可以获得一个很大的线索:

# Warmup Iteration   1: 570.490 us/op
# Warmup Iteration   2: 491.765 us/op
# Warmup Iteration   3: 756.951 us/op
# Warmup Iteration   4: 7033.500 us/op
# Warmup Iteration   5: 7350.080 us/op
# Warmup Iteration   6: 7425.829 us/op
# Warmup Iteration   7: 7029.441 us/op
# Warmup Iteration   8: 7208.584 us/op
# Warmup Iteration   9: 7104.160 us/op
# Warmup Iteration  10: 7372.298 us/op

发生了什么?最初的几次迭代相当快,但随后的第 4 次和后续迭代(以及随后的所有基准迭代)突然变得慢了很多。

我以前见过这个。它位于 SO 上的 this questionthis answer 中。我建议阅读该答案;它解释了在这种情况下 JVM 的内联决策如何导致性能下降。

这里有一点背景知识:for 循环编译为一个非常简单的增量和测试循环,并且可以通过循环剥离和展开等常用优化技术轻松处理。在这种情况下,流代码虽然不是很复杂,但与 for 循环代码相比实际上相当复杂;有相当多的设置,每个循环至少需要一个方法调用。因此,JIT 优化,尤其是其内联决策,对于使流代码快速运行至关重要。而且它有可能出错。

另一个背景点是整数求和是您可以想到的在循环或流中执行的最简单的操作。这将使流设置的固定开销看起来相对更昂贵。它也非常简单,可以触发内联策略中的问题。

另一个答案的建议是添加 JVM 选项 -XX:MaxInlineLevel=12 以增加可以内联的代码量。使用该选项重新运行基准测试会给出:

Benchmark                 (sz)  Mode  Samples    Score  Score error  Units
ArraySum.sumForLoop    1000000  avgt       10  502.379       27.859  us/op
ArraySum.sumStream     1000000  avgt       10  498.572       24.195  us/op

啊,好多了。使用-XX:-TieredCompilation 禁用分层编译也具有避免病态行为的效果。我还发现使循环计算更加昂贵,例如对整数的平方求和——即相加——也可以避免这种病态行为。

现在,您的问题是关于在 leetcode 环境的上下文中运行,这似乎是在您无法控制的 JVM 中运行代码,因此您无法更改内联或编译选项.而且您可能也不希望使您的计算更复杂以避免病理。因此,对于这种情况,您不妨坚持使用旧的 for 循环。但是不要害怕使用流,即使是处理原始数组。除了一些狭窄的边缘情况外,它可以表现得很好。

【讨论】:

    【解决方案2】:

    正常的迭代方法将几乎与任何方法一样快,但流有各种开销:即使它直接来自流,也可能会涉及到一个原始的 Spliterator 和很多正在生成的其他对象。

    通常,您应该期望“正常方法”通常比流更快,除非您同时使用并行化并且您的数据非常大。 p>

    【讨论】:

    • 虽然原则上是正确的,但差异应该在一个合理的范围内。两种解决方案仍然是O(n),因此除非对初始化时间有非常严格的限制,否则不应出现一个解决方案被接受而另一个解决方案因超时而被拒绝的情况。
    • 不要排除流方法可能比直接迭代更快——这种情况经常发生。编写合理且正确的代码;如果您的情况对性能非常敏感,以至于可以产生影响,那么无论如何您都已经定义了性能要求和出色的性能测试,并且您将能够衡量差异。
    • 看低级。 #1 原因是:通过 Spliterator 的每个元素访问开销远低于 Iterator。这会产生巨大的差异。
    • @LouisWasserman 两者都有贡献,但是是正交的。 forEachRemaining 肯定有很大帮助——但即使是一次访问,Spliterator 使用的周期/缓存未命中/分支来获取元素也比 Iterator 少。
    • 如果您只运行一次管道,您将运行完全解释。我们针对热代码的长期性能进行了优化。
    【解决方案3】:

    我的基准测试(见下面的代码)显示流式方法比迭代法慢 10-15%。有趣的是,并行流结果在我的 4 核 (i7) macbook pro 上差异很大,但是,虽然我见过几次它们比迭代快大约 30%,但最常见的结果几乎是三倍 慢 比顺序。

    这里是基准代码:

    import java.util.*;
    import java.util.function.*;
    
    public class StreamingBenchmark {
    
        private static void benchmark(String name, LongSupplier f) {
           long start = System.currentTimeMillis(), sum = 0;
           for(int count = 0; count < 1000; count ++) sum += f.getAsLong();
           System.out.println(String.format(
               "%10s in  %d millis. Sum = %d", 
                name, System.currentTimeMillis() - start, sum
           ));
        }
    
        public static void main(String argv[]) {
            int data[] = new int[1000000];
            Random randy = new Random();
            for(int i = 0; i < data.length; i++) data[i] = randy.nextInt();
    
            benchmark("iterative", () -> { int s = 0; for(int n: data) s+=n; return s; });
            benchmark("stream", () -> Arrays.stream(data).sum());
            benchmark("parallel", () -> Arrays.stream(data).parallel().sum());
    
        }
    }
    

    这是几次运行的输出:

     iterative in  350 millis. Sum = 564821058000
     stream in  394 millis. Sum = 564821058000
     parallel in  883 millis. Sum = 564821058000
    
     iterative in  340 millis. Sum = -295411382000
     stream in  376 millis. Sum = -295411382000
     parallel in  1031 millis. Sum = -295411382000
    
     iterative in  365 millis. Sum = 1205763898000
     stream in  379 millis. Sum = 1205763898000
     parallel in  1053 millis. Sum = 1205763898000
    

    等等

    这让我很好奇,我也尝试在 scala 中运行等效逻辑:

    object Scarr {
        def main(argv: Array[String]) = {
            val randy = new java.util.Random
            val data = (1 to 1000000).map { _ => randy.nextInt }.toArray
            val start = System.currentTimeMillis
            var sum = 0l;
            for ( _ <- 1 to 1000 ) sum += data.sum
            println(sum + " in " + (System.currentTimeMillis - start) + " millis.")
    
        }
    }
    

    这花了 14 秒!比 java 中的流式传输长约 40 倍(!)。哎哟!

    【讨论】:

    • 我很惊讶地看到 SPOJ 和 leetcode 中的一些解决方案获得了 TLE,只是我们使用的是 Java 8 流 API。具有讽刺意味的是,在某些情况下,它们比正常的迭代方法慢 4-5 倍。 oj.leetcode.com/problems/gas-station Java 允许的时间限制约为 2 秒。使用普通的迭代方法可以在 200 毫秒内解决问题,而使用流 API 的解决方案会获得 TLE(花费超过 2 秒)。请参阅我上面的解决方案。令人惊讶的是,对于百万的输入大小,这至少慢了 10 倍。
    • 这不是对代码进行基准测试的正确方法。您的结果数字可能会出现偏差,因为 JVM 可以应用循环优化、执行死代码消除或常量折叠。使用 JMH 进行基准测试。
    【解决方案4】:

    sum() 方法在语法上等同于return reduce(0, Integer::sum); 在一个大列表中,进行所有方法调用的开销将比基本的手动 for 循环迭代更多。 for(int i : numbers) 迭代的字节码仅比手动 for 循环生成的字节码稍微复杂一点。流操作在并行友好的环境中可能更快(尽管可能不适用于原始方法),但除非我们知道环境是并行友好的(并且可能不是,因为 leetcode 本身可能被设计为支持低级而不是抽象因为它衡量的是效率而不是易读性)。

    三种方式(Arrays.stream(int[]).sumfor (int i : ints){total+=i;}for(int i = 0; i &lt; ints.length; i++){total+=i;})中任意一种方式的求和运算在效率上应该是比较相似的。我使用了下面的测试类(对0到4096之间的一亿个整数求和)每次一百次并记录平均时间)。所有这些都在非常相似的时间范围内返回。它甚至试图通过在 while(true) 循环中占用除一个可用内核之外的所有可用内核来限制并行处理,但我仍然没有发现特别的区别:

    public class SumTester {
        private static final int ARRAY_SIZE = 100_000_000;
        private static final int ITERATION_LIMIT = 100;
        private static final int INT_VALUE_LIMIT = 4096;
    
        public static void main(String[] args) {
            Random random = new Random();
            int[] numbers = new int[ARRAY_SIZE];
            IntStream.range(0, ARRAY_SIZE).forEach(i->numbers[i] = random.nextInt(INT_VALUE_LIMIT));
    
            Map<String, ToLongFunction<int[]>> inputs = new HashMap<String, ToLongFunction<int[]>>();
    
            NanoTimer initializer = NanoTimer.start();
            System.out.println("initialized NanoTimer in " + initializer.microEnd() + " microseconds");
    
            inputs.put("sumByStream", SumTester::sumByStream);
            inputs.put("sumByIteration", SumTester::sumByIteration);
            inputs.put("sumByForLoop", SumTester::sumByForLoop);
    
            System.out.println("Parallelables: ");
            averageTimeFor(ITERATION_LIMIT, inputs, Arrays.copyOf(numbers, numbers.length));
    
            int cores = Runtime.getRuntime().availableProcessors();
            List<CancelableThreadEater> threadEaters = new ArrayList<CancelableThreadEater>();
            if (cores > 1){
                threadEaters = occupyThreads(cores - 1);
            }
            // Only one core should be left to our class
            System.out.println("\nSingleCore (" + threadEaters.size() + " of " + cores + " cores occupied)");
            averageTimeFor(ITERATION_LIMIT, inputs, Arrays.copyOf(numbers, numbers.length));
            for (CancelableThreadEater cte : threadEaters){
                cte.end();
            }
            System.out.println("Complete!");
        }
    
        public static long sumByStream(int[] numbers){
            return Arrays.stream(numbers).sum();
        }
    
        public static long sumByIteration(int[] numbers){
            int total = 0;
            for (int i : numbers){
                total += i;
            }
            return total;
        }
    
        public static long sumByForLoop(int[] numbers){
            int total = 0;
            for (int i = 0; i < numbers.length; i++){
                total += numbers[i];
            }
            return total;       
        }
    
        public static void averageTimeFor(int iterations, Map<String, ToLongFunction<int[]>> testMap, int[] numbers){
            Map<String, Long> durationMap = new HashMap<String, Long>();
            Map<String, Long> sumMap = new HashMap<String, Long>();
            for (String methodName : testMap.keySet()){
                durationMap.put(methodName, 0L);
                sumMap.put(methodName, 0L);
            }
            for (int i = 0; i < iterations; i++){
                for (String methodName : testMap.keySet()){
                    int[] newNumbers = Arrays.copyOf(numbers, ARRAY_SIZE);
                    ToLongFunction<int[]> function = testMap.get(methodName);
                    NanoTimer nt = NanoTimer.start();
                    long sum = function.applyAsLong(newNumbers);
                    long duration = nt.microEnd();
                    sumMap.put(methodName, sum);
                    durationMap.put(methodName, durationMap.get(methodName) + duration);
                }
            }
            for (String methodName : testMap.keySet()){
                long duration = durationMap.get(methodName) / iterations;
                long sum = sumMap.get(methodName);
                System.out.println(methodName + ": result '" + sum + "', elapsed time: " + duration + " milliseconds on average over " + iterations + " iterations");
            }
        }
    
        private static List<CancelableThreadEater> occupyThreads(int numThreads){
            List<CancelableThreadEater> result = new ArrayList<CancelableThreadEater>();
            for (int i = 0; i < numThreads; i++){
                CancelableThreadEater cte = new CancelableThreadEater();
                result.add(cte);
                new Thread(cte).start();
            }
            return result;
        }
    
        private  static class CancelableThreadEater implements Runnable {
            private Boolean stop = false;
            public void run(){
                boolean canContinue = true;
                while (canContinue){
                    synchronized(stop){
                        if (stop){
                            canContinue = false;
                        }
                    }
                }           
            }
    
            public void end(){
                synchronized(stop){
                    stop = true;
                }
            }
        }
    
    }
    

    返回的

    initialized NanoTimer in 22 microseconds
    Parallelables: 
    sumByIteration: result '-1413860413', elapsed time: 35844 milliseconds on average over 100 iterations
    sumByStream: result '-1413860413', elapsed time: 35414 milliseconds on average over 100 iterations
    sumByForLoop: result '-1413860413', elapsed time: 35218 milliseconds on average over 100 iterations
    
    SingleCore (3 of 4 cores occupied)
    sumByIteration: result '-1413860413', elapsed time: 37010 milliseconds on average over 100 iterations
    sumByStream: result '-1413860413', elapsed time: 38375 milliseconds on average over 100 iterations
    sumByForLoop: result '-1413860413', elapsed time: 37990 milliseconds on average over 100 iterations
    Complete!
    

    也就是说,在这种情况下没有真正的理由执行 sum() 操作。您正在遍历每个数组,总共进行了三次迭代,最后一次可能是比正常时间更长的迭代。通过阵列的一次完整同时迭代和一次短路迭代可以正确计算。也许可以更有效地做到这一点,但我想不出比我更好的方法来做到这一点。我的解决方案最终成为图表上最快的 Java 解决方案之一 - 它运行时间为 223 毫秒,属于 Python 解决方案的中间包之一。

    如果您愿意,我会添加我的解决方案,但我希望实际问题在这里得到解答。

    【讨论】:

    • 这应该是评论,而不是答案。
    • 我明白,但这无关紧要,抱歉。回答帖子应仅包含对所述问题的完整回答。
    • 我为违反礼节道歉,并编辑了我的帖子以尝试回答实际问题。
    【解决方案5】:

    Stream 功能比较慢。所以在 leetcode 竞赛或任何算法竞赛中,总是更喜欢经典循环而不是流函数,因为大输入容易出现 TLE。这反过来又会导致处罚,从而影响您的最终排名。 这里提到了详细的解释https://stackoverflow.com/a/27994074/6185191

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2013-11-07
      • 2019-05-03
      • 1970-01-01
      • 2014-11-08
      • 2019-07-30
      • 2019-06-20
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多