【问题标题】:Why is a naïve C++ matrix multiplication 100 times slower than BLAS?为什么天真的 C++ 矩阵乘法比 BLAS 慢 100 倍?
【发布时间】:2013-01-10 00:41:10
【问题描述】:

我正在研究大型矩阵乘法并运行以下实验以形成基线测试:

  1. 从标准标准(0 均值,1 标准差)随机生成两个 4096x4096 矩阵 X、Y。
  2. Z = X*Y
  3. 对 Z 的元素求和(以确保它们被访问)并输出。

这里是简单的 C++ 实现:

#include <iostream>
#include <algorithm>

using namespace std;

int main()
{
    constexpr size_t dim = 4096;

    float* x = new float[dim*dim];
    float* y = new float[dim*dim];
    float* z = new float[dim*dim];

    random_device rd;
    mt19937 gen(rd());
    normal_distribution<float> dist(0, 1);

    for (size_t i = 0; i < dim*dim; i++)
    {
        x[i] = dist(gen);
        y[i] = dist(gen);
    }

    for (size_t row = 0; row < dim; row++)
        for (size_t col = 0; col < dim; col++)
        {
            float acc = 0;

            for (size_t k = 0; k < dim; k++)
                acc += x[row*dim + k] * y[k*dim + col];

            z[row*dim + col] = acc;
        }

    float t = 0;

    for (size_t i = 0; i < dim*dim; i++)
        t += z[i];

    cout << t << endl;

    delete x;
    delete y;
    delete z;
}

编译运行:

$ g++ -std=gnu++11 -O3 test.cpp
$ time ./a.out

这是 Octave/matlab 的实现:

X = stdnormal_rnd(4096, 4096);
Y = stdnormal_rnd(4096, 4096);
Z = X*Y;
sum(sum(Z))

运行:

$ time octave < test.octave

引擎盖下的 Octave 正在使用 BLAS(我假设是 sgemm 函数)

硬件是 Linux x86-64 上的 i7 3930X,具有 24 GB 内存。 BLAS 似乎使用了两个内核。也许是超线程对?

我发现在-O3 上使用 GCC 4.7 编译的 C++ 版本需要 9 分钟才能执行:

real    9m2.126s
user    9m0.302s
sys         0m0.052s

八度音阶版本耗时 6 秒:

real    0m5.985s
user    0m10.881s
sys         0m0.144s

我知道 BLAS 已针对所有地狱进行了优化,而朴素的算法完全忽略了缓存等,但严重的是 - 90 次?

谁能解释这个区别? BLAS 实现的架构到底是什么?我看到它正在使用 Fortran,但是 CPU 级别发生了什么?它使用什么算法?它是如何使用 CPU 缓存的?它调用了哪些 x86-64 机器指令? (它是否使用了像 AVX 这样的高级 CPU 功能?)它从哪里获得这种额外的速度?

C++ 算法的哪些关键优化可以使其与 BLAS 版本相提并论?

我在 gdb 下运行 octave 并在计算中途停止了几次。它已经启动了第二个线程,这里是堆栈(所有停止看起来都相似):

(gdb) thread 1
#0  0x00007ffff6e17148 in pthread_join () from /lib/x86_64-linux-gnu/libpthread.so.0
#1  0x00007ffff1626721 in ATL_join_tree () from /usr/lib/libblas.so.3
#2  0x00007ffff1626702 in ATL_join_tree () from /usr/lib/libblas.so.3
#3  0x00007ffff15ae357 in ATL_dptgemm () from /usr/lib/libblas.so.3
#4  0x00007ffff1384b59 in atl_f77wrap_dgemm_ () from /usr/lib/libblas.so.3
#5  0x00007ffff193effa in dgemm_ () from /usr/lib/libblas.so.3
#6  0x00007ffff6049727 in xgemm(Matrix const&, Matrix const&, blas_trans_type, blas_trans_type) () from /usr/lib/x86_64-linux-gnu/liboctave.so.1
#7  0x00007ffff6049954 in operator*(Matrix const&, Matrix const&) () from /usr/lib/x86_64-linux-gnu/liboctave.so.1
#8  0x00007ffff7839e4e in ?? () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
#9  0x00007ffff765a93a in do_binary_op(octave_value::binary_op, octave_value const&, octave_value const&) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
#10 0x00007ffff76c4190 in tree_binary_expression::rvalue1(int) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
#11 0x00007ffff76c33a5 in tree_simple_assignment::rvalue1(int) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
#12 0x00007ffff76d0864 in tree_evaluator::visit_statement(tree_statement&) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
#13 0x00007ffff76cffae in tree_evaluator::visit_statement_list(tree_statement_list&) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
#14 0x00007ffff757f6d4 in main_loop() () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
#15 0x00007ffff7527abf in octave_main () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1

(gdb) thread 2
#0  0x00007ffff14ba4df in ATL_dJIK56x56x56TN56x56x0_a1_b1 () from /usr/lib/libblas.so.3
(gdb) bt
#0  0x00007ffff14ba4df in ATL_dJIK56x56x56TN56x56x0_a1_b1 () from /usr/lib/libblas.so.3
#1  0x00007ffff15a5fd7 in ATL_dmmIJK2 () from /usr/lib/libblas.so.3
#2  0x00007ffff15a6ae4 in ATL_dmmIJK () from /usr/lib/libblas.so.3
#3  0x00007ffff1518e65 in ATL_dgemm () from /usr/lib/libblas.so.3
#4  0x00007ffff15adf7a in ATL_dptgemm0 () from /usr/lib/libblas.so.3
#5  0x00007ffff6e15e9a in start_thread () from /lib/x86_64-linux-gnu/libpthread.so.0
#6  0x00007ffff6b41cbd in clone () from /lib/x86_64-linux-gnu/libc.so.6
#7  0x0000000000000000 in ?? ()

它按预期调用 BLAS gemm

第一个线程似乎正在加入第二个线程,所以我不确定这两个线程是否占观察到的 200% 的 CPU 使用率。

ATL_dgemm libblas.so.3 是哪个库,它的代码在哪里?

$ ls -al /usr/lib/libblas.so.3
/usr/lib/libblas.so.3 -> /etc/alternatives/libblas.so.3

$ ls -al /etc/alternatives/libblas.so.3
/etc/alternatives/libblas.so.3 -> /usr/lib/atlas-base/atlas/libblas.so.3

$ ls -al /usr/lib/atlas-base/atlas/libblas.so.3
/usr/lib/atlas-base/atlas/libblas.so.3 -> libblas.so.3.0

$ ls -al /usr/lib/atlas-base/atlas/libblas.so.3.0
/usr/lib/atlas-base/atlas/libblas.so.3.0

$ dpkg -S /usr/lib/atlas-base/atlas/libblas.so.3.0
libatlas3-base: /usr/lib/atlas-base/atlas/libblas.so.3.0

$ apt-get source libatlas3-base

是 ATLAS 3.8.4

以下是我后来实现的优化:

使用平铺方法,我将 64x64 的 X、Y 和 Z 块预加载到单独的数组中。

更改每个块的计算,使内部循环如下所示:

for (size_t tcol = 0; tcol < block_width; tcol++)
    bufz[trow][tcol] += B * bufy[tk][tcol];

这允许 GCC 自动向量化为 SIMD 指令,还允许指令级并行(我认为)。

开启march=corei7-avx。这获得了 30% 的额外速度,但这是作弊,因为我认为 BLAS 库是预先构建的。

代码如下:

#include <iostream>
#include <algorithm>

using namespace std;

constexpr size_t dim = 4096;
constexpr size_t block_width = 64;
constexpr size_t num_blocks = dim / block_width;

double X[dim][dim], Y[dim][dim], Z[dim][dim];

double bufx[block_width][block_width];
double bufy[block_width][block_width];
double bufz[block_width][block_width];

void calc_block()
{
    for (size_t trow = 0; trow < block_width; trow++)
        for (size_t tk = 0; tk < block_width; tk++)
        {
            double B = bufx[trow][tk];

            for (size_t tcol = 0; tcol < block_width; tcol++)
                bufz[trow][tcol] += B * bufy[tk][tcol];
        }
}

int main()
{
    random_device rd;
    mt19937 gen(rd());
    normal_distribution<double> dist(0, 1);

    for (size_t row = 0; row < dim; row++)
        for (size_t col = 0; col < dim; col++)
        {
            X[row][col] = dist(gen);
            Y[row][col] = dist(gen);
            Z[row][col] = 0;
        }

    for (size_t block_row = 0; block_row < num_blocks; block_row++)
        for (size_t block_col = 0; block_col < num_blocks; block_col++)
        {
            for (size_t trow = 0; trow < block_width; trow++)
                for (size_t tcol = 0; tcol < block_width; tcol++)
                    bufz[trow][tcol] = 0;

            for (size_t block_k = 0; block_k < num_blocks; block_k++)
            {
                for (size_t trow = 0; trow < block_width; trow++)
                    for (size_t tcol = 0; tcol < block_width; tcol++)
                    {
                        bufx[trow][tcol] = X[block_row*block_width + trow][block_k*block_width + tcol];
                        bufy[trow][tcol] = Y[block_k*block_width + trow][block_col*block_width + tcol];
                    }

                calc_block();
            }

            for (size_t trow = 0; trow < block_width; trow++)
                for (size_t tcol = 0; tcol < block_width; tcol++)
                    Z[block_row*block_width + trow][block_col*block_width + tcol] = bufz[trow][tcol];

        }

    double t = 0;

    for (size_t row = 0; row < dim; row++)
        for (size_t col = 0; col < dim; col++)
            t += Z[row][col];

    cout << t << endl;
}

所有的动作都在 calc_block 函数中 - 超过 90% 的时间都花在它上面。

新的时间是:

real    0m17.370s
user    0m17.213s
sys 0m0.092s

哪个更接近。

calc_block函数的反编译如下:

0000000000401460 <_Z10calc_blockv>:
  401460:   b8 e0 21 60 00          mov    $0x6021e0,%eax
  401465:   41 b8 e0 23 61 00       mov    $0x6123e0,%r8d
  40146b:   31 ff                   xor    %edi,%edi
  40146d:   49 29 c0                sub    %rax,%r8
  401470:   49 8d 34 00             lea    (%r8,%rax,1),%rsi
  401474:   48 89 f9                mov    %rdi,%rcx
  401477:   ba e0 a1 60 00          mov    $0x60a1e0,%edx
  40147c:   48 c1 e1 09             shl    $0x9,%rcx
  401480:   48 81 c1 e0 21 61 00    add    $0x6121e0,%rcx
  401487:   66 0f 1f 84 00 00 00    nopw   0x0(%rax,%rax,1)
  40148e:   00 00 
  401490:   c4 e2 7d 19 01          vbroadcastsd (%rcx),%ymm0
  401495:   48 83 c1 08             add    $0x8,%rcx
  401499:   c5 fd 59 0a             vmulpd (%rdx),%ymm0,%ymm1
  40149d:   c5 f5 58 08             vaddpd (%rax),%ymm1,%ymm1
  4014a1:   c5 fd 29 08             vmovapd %ymm1,(%rax)
  4014a5:   c5 fd 59 4a 20          vmulpd 0x20(%rdx),%ymm0,%ymm1
  4014aa:   c5 f5 58 48 20          vaddpd 0x20(%rax),%ymm1,%ymm1
  4014af:   c5 fd 29 48 20          vmovapd %ymm1,0x20(%rax)
  4014b4:   c5 fd 59 4a 40          vmulpd 0x40(%rdx),%ymm0,%ymm1
  4014b9:   c5 f5 58 48 40          vaddpd 0x40(%rax),%ymm1,%ymm1
  4014be:   c5 fd 29 48 40          vmovapd %ymm1,0x40(%rax)
  4014c3:   c5 fd 59 4a 60          vmulpd 0x60(%rdx),%ymm0,%ymm1
  4014c8:   c5 f5 58 48 60          vaddpd 0x60(%rax),%ymm1,%ymm1
  4014cd:   c5 fd 29 48 60          vmovapd %ymm1,0x60(%rax)
  4014d2:   c5 fd 59 8a 80 00 00    vmulpd 0x80(%rdx),%ymm0,%ymm1
  4014d9:   00 
  4014da:   c5 f5 58 88 80 00 00    vaddpd 0x80(%rax),%ymm1,%ymm1
  4014e1:   00 
  4014e2:   c5 fd 29 88 80 00 00    vmovapd %ymm1,0x80(%rax)
  4014e9:   00 
  4014ea:   c5 fd 59 8a a0 00 00    vmulpd 0xa0(%rdx),%ymm0,%ymm1
  4014f1:   00 
  4014f2:   c5 f5 58 88 a0 00 00    vaddpd 0xa0(%rax),%ymm1,%ymm1
  4014f9:   00 
  4014fa:   c5 fd 29 88 a0 00 00    vmovapd %ymm1,0xa0(%rax)
  401501:   00 
  401502:   c5 fd 59 8a c0 00 00    vmulpd 0xc0(%rdx),%ymm0,%ymm1
  401509:   00 
  40150a:   c5 f5 58 88 c0 00 00    vaddpd 0xc0(%rax),%ymm1,%ymm1
  401511:   00 
  401512:   c5 fd 29 88 c0 00 00    vmovapd %ymm1,0xc0(%rax)
  401519:   00 
  40151a:   c5 fd 59 8a e0 00 00    vmulpd 0xe0(%rdx),%ymm0,%ymm1
  401521:   00 
  401522:   c5 f5 58 88 e0 00 00    vaddpd 0xe0(%rax),%ymm1,%ymm1
  401529:   00 
  40152a:   c5 fd 29 88 e0 00 00    vmovapd %ymm1,0xe0(%rax)
  401531:   00 
  401532:   c5 fd 59 8a 00 01 00    vmulpd 0x100(%rdx),%ymm0,%ymm1
  401539:   00 
  40153a:   c5 f5 58 88 00 01 00    vaddpd 0x100(%rax),%ymm1,%ymm1
  401541:   00 
  401542:   c5 fd 29 88 00 01 00    vmovapd %ymm1,0x100(%rax)
  401549:   00 
  40154a:   c5 fd 59 8a 20 01 00    vmulpd 0x120(%rdx),%ymm0,%ymm1
  401551:   00 
  401552:   c5 f5 58 88 20 01 00    vaddpd 0x120(%rax),%ymm1,%ymm1
  401559:   00 
  40155a:   c5 fd 29 88 20 01 00    vmovapd %ymm1,0x120(%rax)
  401561:   00 
  401562:   c5 fd 59 8a 40 01 00    vmulpd 0x140(%rdx),%ymm0,%ymm1
  401569:   00 
  40156a:   c5 f5 58 88 40 01 00    vaddpd 0x140(%rax),%ymm1,%ymm1
  401571:   00 
  401572:   c5 fd 29 88 40 01 00    vmovapd %ymm1,0x140(%rax)
  401579:   00 
  40157a:   c5 fd 59 8a 60 01 00    vmulpd 0x160(%rdx),%ymm0,%ymm1
  401581:   00 
  401582:   c5 f5 58 88 60 01 00    vaddpd 0x160(%rax),%ymm1,%ymm1
  401589:   00 
  40158a:   c5 fd 29 88 60 01 00    vmovapd %ymm1,0x160(%rax)
  401591:   00 
  401592:   c5 fd 59 8a 80 01 00    vmulpd 0x180(%rdx),%ymm0,%ymm1
  401599:   00 
  40159a:   c5 f5 58 88 80 01 00    vaddpd 0x180(%rax),%ymm1,%ymm1
  4015a1:   00 
  4015a2:   c5 fd 29 88 80 01 00    vmovapd %ymm1,0x180(%rax)
  4015a9:   00 
  4015aa:   c5 fd 59 8a a0 01 00    vmulpd 0x1a0(%rdx),%ymm0,%ymm1
  4015b1:   00 
  4015b2:   c5 f5 58 88 a0 01 00    vaddpd 0x1a0(%rax),%ymm1,%ymm1
  4015b9:   00 
  4015ba:   c5 fd 29 88 a0 01 00    vmovapd %ymm1,0x1a0(%rax)
  4015c1:   00 
  4015c2:   c5 fd 59 8a c0 01 00    vmulpd 0x1c0(%rdx),%ymm0,%ymm1
  4015c9:   00 
  4015ca:   c5 f5 58 88 c0 01 00    vaddpd 0x1c0(%rax),%ymm1,%ymm1
  4015d1:   00 
  4015d2:   c5 fd 29 88 c0 01 00    vmovapd %ymm1,0x1c0(%rax)
  4015d9:   00 
  4015da:   c5 fd 59 82 e0 01 00    vmulpd 0x1e0(%rdx),%ymm0,%ymm0
  4015e1:   00 
  4015e2:   c5 fd 58 80 e0 01 00    vaddpd 0x1e0(%rax),%ymm0,%ymm0
  4015e9:   00 
  4015ea:   48 81 c2 00 02 00 00    add    $0x200,%rdx
  4015f1:   48 39 ce                cmp    %rcx,%rsi
  4015f4:   c5 fd 29 80 e0 01 00    vmovapd %ymm0,0x1e0(%rax)
  4015fb:   00 
  4015fc:   0f 85 8e fe ff ff       jne    401490 <_Z10calc_blockv+0x30>
  401602:   48 83 c7 01             add    $0x1,%rdi
  401606:   48 05 00 02 00 00       add    $0x200,%rax
  40160c:   48 83 ff 40             cmp    $0x40,%rdi
  401610:   0f 85 5a fe ff ff       jne    401470 <_Z10calc_blockv+0x10>
  401616:   c5 f8 77                vzeroupper 
  401619:   c3                      retq   
  40161a:   66 0f 1f 44 00 00       nopw   0x0(%rax,%rax,1)

【问题讨论】:

  • 欢迎来到缓存的世界。他们可以做出惊人的事情......
  • 150 倍的加速比您可能能够摆脱更好的缓存行为和针对您的问题的其他优化。
  • 您也在为 32m 随机数的生成计时!我不知道这些 RNG 有多快,但要真正了解问题可能出在哪里,您应该只计算乘法时间!例如,您可以使用一些低级 API 或 boost::auto_cpu_timer 来测量程序中的时间。
  • @us2012:我已经对其进行了测试,以确保在这两种情况下乘法都占主导地位。
  • acc += x[row*dim + k] * y[k*dim + col]; z[row*dim + col] = acc; 太糟糕了。

标签: c++ linux matlab c++11 matrix-multiplication


【解决方案1】:

我不知道这些信息有多可靠,但Wikipedia 说 BLAS 使用 Strassen 的算法来处理大矩阵。你的确实很大。这大约是 O(n^2.807),比你的 O(n^3) 天真的算法要好。

【讨论】:

  • 4096^2.807 = 1.38005e+10, 4096^3 = 6.87195e+10。此外,Strassen 的恒定开销更高。这仅解释了 OP 看到的一小部分差异。
  • 应该是 (4096^2)^2.807 vs (4096^2)^3。
  • @DeadMG:我认为他们使用的 n 是矩阵宽度,而不是输入大小。天真的算法显然是 O(width^3) 而不是 O(width^6)
  • @us2012 是的,如果常量相同,它应该快 5 倍左右。但他们从来都不是。不同的算法也会产生不同的优化可能性(管道、缓存、内存对齐)。
  • @Csq 当然。我并不是说你的答案是错误的,只是说这只是全局的一小部分。
【解决方案2】:

大约一半的差异用于算法改进。 (4096*4096)^3 是算法的复杂度,即 4.7x10^21,(4096*4096)^2.807 是 1x10^20。这是大约 47 倍的差异。

另外 2 倍将通过更智能地使用缓存、SSE 指令和其他此类低级优化来解决。

编辑:我撒谎,n 是宽度,而不是宽度^2。该算法实际上只占大约 4 倍,所以还有大约 22 倍要走。线程、缓存和 SSE 指令很可能会解决这些问题。

【讨论】:

  • 我认为你的数字是错误的。 nieve 算法是 ~4096^3。看看循环。它是 4096 范围内的三个嵌套 for 循环。
  • 另外,我们不知道在 naive 和 stassen 之间可能存在差异的常数因素(实际上相差多少),因此我们无法在此分辨率下进行比较。不过,Strassen 算法值得研究。
  • @user1131467:Wikipedia page 表示 Strassen 在 n = 100 附近比直接乘法略好。这表明因子是 1003 / 1002.807 = 2.4。显然,这将根据处理器型号、与缓存效果交互的矩阵大小等而有很大差异。简单的推断表明,在 n = 4096 时,Strassen 的性能大约是直接乘法的两倍。
  • 取决于现代 CPU 所涉及的 SSE 类型、缓存效果和线程注意事项。即使在经过测试几年后也可能会改变答案。
  • 刚刚注意到我旧评论中的格式问题。 (“**”的解释是否改变为粗体?)这是一个更新:@user1131467:Wikipedia page 说 Strassen 比 n = 100 附近的直接乘法略好。这表明因子是 100^3 / 100^ 2.807 = 2.4。显然,这将根据处理器型号、与缓存效果交互的矩阵大小等而有很大差异。简单的推断表明,在 n = 4096 时,Strassen 的性能大约是直接乘法的两倍。
【解决方案3】:

以下是导致您的代码与 BLAS 之间的性能差异的三个因素(以及关于 Strassen 算法的注释)。

在你的内部循环中,在k,你有y[k*dim + col]。由于内存缓存的排列方式,具有相同dimcolk 的连续值映射到相同的缓存集。缓存的结构方式是,每个内存地址都有一个缓存集,当它在缓存中时,它的内容必须被保存。每个缓存集都有几行(通常为四行),每一行都可以保存映射到该特定缓存集的任何内存地址。

因为您的内部循环以这种方式遍历y,所以每次它使用来自y 的元素时,它必须将该元素的内存加载到与前一次迭代相同的集合中。这会强制驱逐集合中的先前缓存行之一。然后,在col 循环的下一次迭代中,y 的所有元素都已从缓存中逐出,因此必须重新加载它们。

因此,每次循环加载y 的元素时,它必须从内存中加载,这需要很多 CPU 周期。

高性能代码可以通过两种方式避免这种情况。一,它将工作分成更小的块。行和列被划分为更小的尺寸,并通过更短的循环进行处理,这些循环能够使用高速缓存行中的所有元素,并在每个元素进入下一个块之前多次使用它们。因此,大多数对x 元素和y 元素的引用来自缓存,通常在单个处理器周期中。第二,在某些情况下,代码会将数据从矩阵的一列(由于几何形状而颠簸缓存)复制到临时缓冲区的行中(避免颠簸)。这再次允许从缓存而不是从内存中提供大部分内存引用。

另一个因素是使用单指令多数据 (SIMD) 功能。许多现代处理器的指令在一条指令中加载多个元素(典型的是四个 float 元素,但现在有些是八个),存储多个元素,添加多个元素(例如,对于这四个中的每一个,将其添加到相应的一个这四个),乘以多个元素,等等。只要您能够安排工作以使用这些指令,只需使用这些指令即可立即使您的代码速度提高四倍。

这些指令在标准 C 中不能直接访问。现在一些优化器会尽可能地尝试使用这些指令,但是这种优化很困难,并且从中获得太多好处并不常见。许多编译器提供了语言的扩展,可以访问这些指令。就个人而言,我通常更喜欢用汇编语言编写来使用 SIMD。

另一个因素是在处理器上使用指令级并行执行功能。请注意,在您的内部循环中,acc 已更新。在上一次迭代完成更新acc 之前,下一次迭代无法添加到acc。相反,高性能代码将保持多个总和并行运行(甚至是多个 SIMD 总和)。这样做的结果是,在执行一个总和的加法时,将开始另一个总和的加法。在当今的处理器上,一次支持四个或更多的浮点运算是很常见的。如所写,您的代码根本无法做到这一点。一些编译器会尝试通过重新排列循环来优化代码,但这要求编译器能够看到特定循环的迭代是相互独立的,或者可以与另一个循环交换,等等。

使用缓存有效地提供十倍的性能提升,SIMD 提供另外四倍,指令级并行提供另外四倍,总共提供 160 倍,这是完全可行的。

这是基于this Wikipedia page 对 Strassen 算法效果的一个非常粗略的估计。维基百科页面说 Strassen 在 n = 100 左右比直接乘法略好。这表明执行时间的常数因子之比为 1003 / 1002.807 ≈ 2.4 .显然,这将因处理器型号、与缓存效果交互的矩阵大小等而有很大差异。然而,简单的推断表明,在 n = 4096 ((4096/100)3-2.807 ≈ 2.05) 时,Strassen 的性能大约是直接乘法的两倍。同样,这只是一个大概的估计。

至于后面的优化,内循环考虑这段代码:

bufz[trow][tcol] += B * bufy[tk][tcol];

一个潜在的问题是bufz 通常会与bufy 重叠。由于您对bufzbufy 使用全局定义,因此编译器可能知道它们在这种情况下不会重叠。但是,如果您将此代码移动到以bufzbufy 作为参数传递的子例程中,特别是如果您在单独的源文件中编译该子例程,则编译器不太可能知道bufz 和@ 987654346@ 不重叠。在这种情况下,编译器无法对代码进行向量化或重新排序,因为此迭代中的 bufz[trow][tcol] 可能与另一迭代中的 bufy[tk][tcol] 相同。

即使编译器可以看到在当前源模块中使用不同的bufzbufy 调用子例程,如果例程具有extern 链接(默认),那么编译器必须允许要从外部模块调用的例程,因此如果bufzbufy 重叠,它必须生成可以正常工作的代码。 (编译器可以处理的一种方法是生成两个版本的例程,一个从外部模块调用,一个从当前正在编译的模块调用。是否这样做取决于您的编译器,优化开关,等等。)如果您将例程声明为static,则编译器知道它不能从外部模块调用(除非您获取其地址并且该地址有可能传递到当前模块之外)。

另一个潜在的问题是,即使编译器将此代码向量化,它也不一定会为您执行的处理器生成最佳代码。查看生成的汇编代码,编译器似乎只重复使用%ymm1。一遍又一遍,它将内存中的一个值乘以%ymm1,将内存中的一个值添加到%ymm1,并将%ymm1 中的一个值存储到内存中。这有两个问题。

第一,您不希望这些部分和经常存储到内存中。您希望将许多加法累积到一个寄存器中,并且该寄存器只会很少写入内存。说服编译器这样做可能需要重写代码以明确将部分和保存在临时对象中并在循环完成后将它们写入内存。

第二,这些指令名义上是串行相关的。在乘法完成之前添加无法开始,并且在添加完成之前存储无法写入内存。 Core i7 具有强大的乱序执行能力。因此,虽然它有等待开始执行的加法,但它稍后会在指令流中查看乘法并启动它。 (即使该乘法也使用%ymm1,处理器会即时重新映射寄存器,以便它使用不同的内部寄存器来执行此乘法。)即使您的代码充满了连续的依赖关系,处理器也会尝试执行几个立即指示。但是,有很多事情会干扰这一点。您可能会用完处理器用于重命名的内部寄存器。您使用的内存地址可能会遇到错误的冲突。 (处理器查看十几个内存地址的低位,以查看该地址是否与它尝试从较早指令加载或存储的另一个地址相同。如果位相等,则处理器具有延迟当前加载或存储,直到它可以验证整个地址不同。这种延迟可能会比当前加载或存储更多。)因此,最好有完全独立的指令。

这也是我更喜欢在汇编中编写高性能代码的另一个原因。要在 C 中做到这一点,您必须说服编译器给您这样的指令,方法是编写一些您自己的 SIMD 代码(使用它们的语言扩展)和手动展开循环(编写多个迭代)。

在复制入和复制出缓冲区时,可能会出现类似问题。但是,您报告 90% 的时间都花在了calc_block,所以我没有仔细研究过。

【讨论】:

  • 这个答案很好,因为它回答了算法性能中许多难以理解的部分,但正如其他答案中提到的,天真的乘法比 strassen 或其他超快算法慢得多对于大型矩阵。除非我读错了。这些算法因矩阵的大小而有很大差异,因此它甚至可能不是一个因素。
  • 是的,在我的异构并行编程课程中,我们使用 GPU 上的 cuda 实现了分块矩阵乘法,其中涉及优化 SIMD 并使用块缓存(opencl 中的“工作组”)加载矩阵进入缓存,然后本地化工作。不过,在单线程环境中看到此类技术的 90 倍改进是非常令人惊讶的。我会调查并尝试支持您的主张。
  • 如果我们转置 y 应该会立即改善,因为我们现在可以按列顺序访问它。
  • 好点!我要补充一点,以测试缓存因素对性能的影响,y[k*dim + col] 可以简单地更改为y[col*dim + k],只是为了看看它离 BLAS 的性能还有多远。
  • @MikaelPersson:废话。 real 1m3.350s user 1m2.972s sys 0m0.148s。 9 倍改进。感人的。 `
【解决方案4】:

Strassen 算法与朴素算法相比有两个优点:

  1. 正如其他答案正确指出的那样,就操作数量而言,时间复杂度更高;
  2. 这是一个cache-oblivious algorithmThe difference in number of cache misses 的顺序是 B*M½,其中 B 是缓存行大小,M 是缓存大小。

我认为第二点是造成您所经历的减速的主要原因。如果您在 Linux 下运行应用程序,我建议您使用 perf 工具运行它们,该工具会告诉您程序正在经历多少缓存未命中。

【讨论】:

    【解决方案5】:

    这是一个相当复杂的话题,Eric 在上面的帖子中回答得很好。我只想指出这个方向的有用参考,第 84 页:

    http://www.rrze.fau.de/dienste/arbeiten-rechnen/hpc/HPC4SE/

    这建议在阻塞之上进行“循环展开和堵塞”。

    谁能解释这个区别?

    一般的解释是,操作数/数据数之比为O(N^3)/O(N^2)。因此,矩阵-矩阵乘法是一种缓存绑定算法,这意味着对于大矩阵大小,您不会遇到常见的内存带宽瓶颈。 如果代码优化得当,您可以获得高达 90% 的 CPU 峰值性能。因此,正如您所观察到的,Eric 阐述的优化潜力是巨大的。实际上,看到性能最好的代码并用另一个编译器编译你的最终程序会很有趣(英特尔通常吹嘘自己是最好的)。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2012-07-14
      • 2012-06-22
      • 2018-04-11
      • 1970-01-01
      • 2020-07-13
      • 2021-08-05
      • 2019-03-16
      • 1970-01-01
      相关资源
      最近更新 更多