【问题标题】:Optimize summation code优化求和代码
【发布时间】:2017-04-14 03:03:30
【问题描述】:

来自采访:

int fn(int a, int b)
{
    int sum = 0; 
    for (int i = a * 4; i > 0; i--)
    {
        sum += b * i * i;
    } 
    return sum;
}

如何进一步优化这段代码?我知道有一个求和公式,但我不认为记住这样的公式是面试官想要的。那么,你会如何优化它呢?

编辑:感谢 chqrlie、faivvy、asimes 和 Ap31 的建议和回答。所以我想现在将有三种优化方法:

  1. 我们可以在返回时进行,而不是在每次迭代中乘以 b。
  2. 用公式替换 for 循环:n * (n + 1) * (2 * n + 1) / 6 * b。 A simple derivation
  3. 使用循环展开。请看asimes发的帖子。

在这三个答案中,我可能会选择 1 和 3,因为它们可以应用于具有相似结构的所有类型的代码。您应该提到有一个公式可以用作奖励,但我怀疑公式是否是面试官想要的。

还有其他建议吗?

【问题讨论】:

  • b*i*i中删除b*,然后删除return sum*b
  • 也许他们并不在乎是否记住它,但他们只是想让您说公式存在并且可以轻松研究。加分,推导出公式或至少写出需要简化的求和表达式。
  • 他们是否认为a 总是积极的?另外,您熟悉循环展开吗?
  • 我会使用求和公式对其进行优化。在一次采访中,如果我知道有一个求和公式但无法立即回忆起来,我会提到我知道有这样的东西以及我将如何找到它(例如,我会检查哪些参考资料)。解决问题不需要知道所有的答案。它需要知道如何获得或制定合适的解决方案。
  • @SIR_Vampire:您可以通过单击分数下方的灰色复选标记来接受其中一个答案。

标签: c++ c performance optimization


【解决方案1】:

公式:1*1+2*​​2+...+n*n = n(n+1)(2n+1)/6

int fn(int a, int b)
{
    a <<= 2;
    return (a*(a + 1)*((a << 1) + 1) / 6) * b;
}

这是你想要的吗?

【讨论】:

  • 这不等同于原始代码,例如a 可能是否定的。在这个时代,显式位移看起来也有点不必要
  • @Ap31 是的,不等同于原代码。从编程的角度,我们应该考虑否定的情况。但他只是要求优化和公式,我会给出一个方向。
  • 好的,溢出情况怎么样?你的乘法很容易越界,也许最好不要事先做a &lt;&lt;= 2
  • @Ap31 即使按照他的代码sum += b*i*i;,它仍然很容易溢出。我的意思是他在写下这样的代码之前必须知道一些条件,比如 a。
  • 不,原始代码是否溢出无关紧要 - 重要的是,如果原始代码适用于 ab 的特定值,您的代码也应该适用于它们 -而事实并非如此
【解决方案2】:

除非a 为负数,否则函数fn 计算b 乘以直到4*a 的平方和。

1n 的平方和可以计算为n(n+1)(2n+1)/6

这是一个 C 翻译:

int fn(int a, int b) {
    if (a <= 0 || b == 0) {
        return 0;
    } else {
        int n = a * 4;
        return n * (n + 1) * (2 * n + 1) / 6 * b;
    }
}

正如 Ap31 所指出的,clang 足够精明,可以检测循环优化并将原始函数转换为直接计算,但它将上述代码编译为much more compact 16 assembly instructions(原始代码为 36)。

为了避免中间结果的潜在溢出,这里有一个稍微不同的公式,它不会计算更大的中间结果:

int fn(int a, int b) {
    if (a <= 0 || b == 0) {
        return 0;
    } else {
        if (a % 3 == 0)
            return (a / 3) * (4 * a + 1) * (8 * a + 1) * b * 2;
        else
            return (4 * a + 1) * (8 * a + 1) / 3 * a * b * 2;
    }
}

如果类型 long long 大于 int,则更简单的替代方法是:

int fn(int a, int b) {
    if (a <= 0 || b == 0) {
        return 0;
    } else {
        unsigned long long n = a * 4;
        return (int)(n * (n + 1) * (2 * n + 1) / 6 * b);
    }
}

【讨论】:

  • 天哪,我应该在深夜把我的cmets留给自己。我的坏
  • 正如我在 cmets 中对我的回答所提到的那样,由于乘法中可能溢出,这与原来的不完全等价,但否则你的答案看起来真的很好
  • @Ap31:您提出了一个有效点:中间结果 n * (n + 1) * (2 * n + 1) 可能在除以 6 和随后的乘以 b 之前溢出。根据b 的值,a 的某些值可能会产生错误的结果,而求和则不会。有两种方法可以解决这个问题:使用long longunsigned long long 进行中间计算,这仍然比求和快得多,但如果sizeof(int)==sizeof(long long) 可能不够用,或者找到一个没有更大的公式中间结果。
【解决方案3】:

面试官当然希望从@faivvy(和@chqrlie)的答案中得到优化,你总是可以推导出公式或者只是说你知道它存在并且你可以完全摆脱循环。

不要忘记一些常见的错误:a 可能是负数,a*a*(2*a + 1) 可能溢出。

另外需要注意的是modern compilers can do this by themselves - 你也可以向面试官提及。

【讨论】:

  • 我不知道这是否是面试官所期望的。如果我在他的位置,我不会。但话又说回来,我也不会一开始就问这个愚蠢的问题。
  • clang 确实令人印象深刻,但程序员还是可以通过一些思考做得更好:godbolt.org/g/B0mmbe
  • @chqrlie 再次,您的解决方案可能会溢出原始代码工作的乘法。 clang 的优化说明了这一点
【解决方案4】:

正如@faivvy 在他的回答中指出的那样,您可以尝试完全取消 for 循环

但是,另一种方法(正确处理负数a)是执行循环展开,我将调用该函数fnUnroll。如果您不熟悉循环展开,其想法是减少迭代次数并并行求和

正如cmets中提到的,每次迭代不需要乘以b,可以在最后完成。我添加了另一个名为 fnUnrollNoMult 的函数来显示这个

#include <chrono>
#include <cstdlib>
#include <iostream>

int fn(int a, int b) {
    int sum = 0;
    for (int i = a * 4; i > 0; i--)
        sum += b * i * i;
    return sum;
}

int fnUnroll(int a, int b) {
    // Set up some number of accumulators, I picked 4
    int sum0 = 0;
    int sum1 = 0;
    int sum2 = 0;
    int sum3 = 0;

    int i = 1;
    int limit = a * 4;

    // Sum 4 values in parallel
    for ( ; i < limit; i += 4) {
        sum0 += b * i * i;
        sum1 += b * (i + 1) * (i + 1);
        sum2 += b * (i + 2) * (i + 2);
        sum3 += b * (i + 3) * (i + 3);
    }

    // Handle the remainder (if any)
    for ( ; i < limit; i++)
        sum0 += b * i + i;

    // Sum the accumulators
    return sum0 + sum1 + sum2 + sum3;
}

int fnUnrollNoMult(int a, int b) {
    int sum0 = 0;
    int sum1 = 0;
    int sum2 = 0;
    int sum3 = 0;

    // Remove b from the loops
    int i = 1;
    int limit = a * 4;
    for ( ; i < limit; i += 4) {
        sum0 += i * i;
        sum1 += (i + 1) * (i + 1);
        sum2 += (i + 2) * (i + 2);
        sum3 += (i + 3) * (i + 3);
    }
    for ( ; i < limit; i++)
        sum0 += i + i;

    // Handle b here
    return b * (sum0 + sum1 + sum2 + sum3);
}

int main(int argc, char** argv) {
    // Expects two arguments: a and b
    if (argc != 3) {
        std::cout << "Usage: " << argv[0] << " <int> <int>\n";
        return 1;
    }

    int a = atoi(argv[1]);
    int b = atoi(argv[2]);

    // This is just to demonstrate correctness
    for (int i = 0; i < 100; i++)
        for (int j = 0; j < 100; j++)
            if (
                fn(i, j) != fnUnroll(i, j) ||
                fn(i, j) != fnUnrollNoMult(i, j)
            ) {
                std::cout << "Not equal: " << i << ", " << j << std::endl;
                return 1;
            }

    // Benchmark
    using namespace std::chrono;
    {
        auto start = high_resolution_clock::now();
        int result = fn(a, b);
        auto stop  = high_resolution_clock::now();
        std::cout << "fn value:             " << result << std::endl;
        std::cout << "fn nanos:             " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
    }
    {
        auto start = high_resolution_clock::now();
        int result = fnUnroll(a, b);
        auto stop  = high_resolution_clock::now();
        std::cout << "fnUnroll value:       " << result << std::endl;
        std::cout << "fnUnroll nanos:       " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
    }
    {
        auto start = high_resolution_clock::now();
        int result = fnUnrollNoMult(a, b);
        auto stop  = high_resolution_clock::now();
        std::cout << "fnUnrollNoMult value: " << result << std::endl;
        std::cout << "fnUnrollNoMult nanos: " << duration_cast<nanoseconds>(stop - start).count() << std::endl;
    }

    return 0;
}

下面的程序需要两个参数,分别代表ab。下面我将程序编译为g++ -std=c++14 foo.cpp -O3,并得到一些a 值的这些结果:

./a.out 1 2
fn value:             60
fn nanos:             373
fnUnroll value:       60
fnUnroll nanos:       209
fnUnrollNoMult value: 60
fnUnrollNoMult nanos: 157
./a.out 1000 2
fn value:             -267004960
fn nanos:             3509
fnUnroll value:       -267004960
fnUnroll nanos:       2820
fnUnrollNoMult value: -267004960
fnUnrollNoMult nanos: 1568
./a.out 1000000 2
fn value:             -619707648
fn nanos:             3137685
fnUnroll value:       -619707648
fnUnroll nanos:       2387840
fnUnrollNoMult value: -619707648
fnUnrollNoMult nanos: 1220519

【讨论】:

  • 令人印象深刻的工作......但一个明显的过早优化案例:该函数计算 b 乘以平方和,包括 4*a。循环可以变成一个简单的表达式。
  • @chqrlie,表达式如何处理a = -1
  • 三元运算符怎么样?或者,为了可读性,单个 if 测试。
  • @chqrlie,我不认为我在听你说的话,你能更详细地解释一下吗?我以为你得到了类似 faivvy 所展示的东西,但现在我很困惑
  • 在 C 中是:C11 3.4.3: 示例:未定义行为的一个示例是整数溢出时的行为。定义了无符号算术运算的溢出,但 有符号 整数运算的溢出是未定义的行为。
猜你喜欢
  • 2022-11-23
  • 2023-04-11
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多