【问题标题】:Speed up Cython implementation of dot product multiplication加快点积乘法的 Cython 实现
【发布时间】:2022-01-19 16:07:12
【问题描述】:

我正在尝试通过在点积操作np.dot(a,b) 上超越 Numpy 来学习 cython。但我的实现慢了大约 4 倍。

所以,这是我的 hello.pyx 文件 cython 实现:

cimport numpy as cnp
cnp.import_array()

cpdef double dot_product(double[::1] vect1, double[::1] vect2):
    cdef int size = vect1.shape[0]
    cdef double result = 0
    cdef int i = 0
    while i < size:
        result += vect1[i] * vect2[i]
        i += 1
    return result

这是我的 .py 测试文件:

import timeit

setup = '''
import numpy as np
import hello

n = 10000
a = np.array([float(i) for i in range(n)])
b = np.array([i/2 for i in a])
'''
lf_code = 'res_lf = hello.dot_product(a, b)'
np_code = 'res_np = np.dot(a,b)'
n = 100
lf_time = timeit.timeit(lf_code, setup=setup, number=n) * 100
np_time = timeit.timeit(np_code, setup=setup, number=n) * 100

print(f'Lightning fast time: {lf_time}.')
print(f'Numpy time: {np_time}.')

控制台输出:

Lightning fast time: 0.12186000000156127.
Numpy time: 0.028800000001183435.

构建 hello.pyx 的命令:

python setup.py build_ext --inplace

setup.py 文件:

from distutils.core import Extension, setup
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()])
setup(ext_modules=cythonize(ext))

处理器: i7-7700T @ 2.90 GHz

【问题讨论】:

  • 你正在与OpenBLAS(或类似的)竞争:numpy 只包装这个计算。也许你的野心太大了。 OpenBLAS 和公司。具有 CPU 架构特定的内核、SIMD、多线程和汇编代码。
  • 我认为你不能比 numpy 做得更好,但挑战很好。 Numpy 使用向量指令,这些指令应该同时处理多个数字。顺便说一句,使用循环并不是进行矩阵乘法的最佳方式。我建议你看看 Strassen 和 Coppersmith 算法。如果你想比numpy做得更好,就用显卡; Numba cuda 可以毫不费力地进行并行计算。
  • 你不是第一个尝试失败的人:stackoverflow.com/q/10442365/5769463 cython 比 c 还要难,因为 cython 可以生成不易优化的 c 代码
  • 就其价值而言,比 OpenBLAS 慢 4 倍意味着您可能在不采取极端措施的情况下尽可能地做到了。
  • @Neofelis(和@ead)这不是矩阵乘法而是点积。因此,Strassen 和 Coppersmith 算法在这里不相关。此外,不,GPU 对于这样的任务不会更快,因为数组太小了。在 GPU 上执行内核的延迟通常高于这个时间(GPU 针对大工作负载进行了优化,显然不是低延迟)。更不用说向 GPU 传输数据的速度很慢,而且这种操作通常受内存限制(就像所有 BLAS 1 级原语一样)。

标签: python numpy performance cython


【解决方案1】:

问题主要来自与 Numpy(在大多数平台上默认使用 OpenBLAS)相比,缺少 SIMD 指令(由于边界检查和低效的默认编译器标志)。

要解决这个问题,您应该首先在 hello.pix 文件的开头添加以下行:

#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False

那么,你应该使用这个新的setup.py 文件:

from distutils.core import Extension, setup
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()], extra_compile_args=['-O3', '-mavx', '-ffast-math'])
setup(ext_modules=cythonize(ext))

请注意,标志取决于编译器。话虽如此,Clang 和 GCC 都支持它们(可能也支持 ICC)。 -O3 告诉 Clang 和 GCC 使用更积极的优化,例如代码的自动矢量化。 -mavx 告诉他们使用 AVX 指令集(仅在相对较新的 x86-64 处理器上可用)。 -ffast-math 告诉他们假设浮点数运算是关联的(事实并非如此),并且您只使用有限/基本数字(没有 NaN,也没有无穷大)。如果不满足上述假设,则程序可能会在运行时崩溃,因此小心此类标志。

请注意,OpenBLAS 会根据您的机器和 AFAIK 自动选择指令集,它不使用 -ffast-math,而是使用更安全(低级)的替代方案。


结果:

这是我机器上的结果:

Before optimization:
  Lightning fast time: 0.10018469997703505.
  Numpy time: 0.024747799989199848.

After (with GCC):
  Lightning fast time: 0.02865879996534204.
  Numpy time: 0.02456870001878997.

After (with Clang):
  Lightning fast time: 0.01965239998753532.
  Numpy time: 0.024799799984975834.

Clang 生成的代码在我的机器上比 Numpy 快


在引擎盖下

对我机器上处理器执行的汇编代码的分析表明,该代码仅使用慢速标量指令,包含不必要的边界检查,主要受result += ... 操作的限制(因为循环携带依赖)。

162e3:┌─→movsd  xmm0,QWORD PTR [rbx+rax*8]  # Load 1 item
162e8:│  mulsd  xmm0,QWORD PTR [rsi+rax*8]  # Load 1 item
162ed:│  addsd  xmm1,xmm0                   # Main bottleneck (accumulation)
162f1:│  cmp    rdi,rax
162f4:│↓ je     163f8                       # Bound checking conditional jump
162fa:│  cmp    rdx,rax
162fd:│↓ je     16308                       # Bound checking conditional jump
162ff:│  add    rax,0x1
16303:├──cmp    rcx,rax
16306:└──jne    162e3

优化后的结果是:

13720:┌─→vmovupd      ymm3,YMMWORD PTR [r13+rax*1+0x0]    # Load 4 items
13727:│  vmulpd       ymm0,ymm3,YMMWORD PTR [rcx+rax*1]   # Load 4 items
1372c:│  add          rax,0x20
13730:│  vaddpd       ymm1,ymm1,ymm0        # Still a bottleneck (but better)
13734:├──cmp          rdx,rax
13737:└──jne          13720

result += ... 操作仍然是优化版本中的瓶颈,但这要好得多,因为循环一次可以处理 4 个项目。要消除瓶颈,必须部分展开循环。但是,GCC(这是我机器上的默认编译器)无法正确执行此操作(即使要求使用 -funrol-loops(由于循环携带的依赖)。这就是为什么 OpenBLAS 应该比GCC 生成的代码。

希望 Clang 默认能够做到这一点。下面是 Clang 生成的代码:

59e0:┌─→vmovupd      ymm4,YMMWORD PTR [rax+rdi*8]       # load 4 items
59e5:│  vmovupd      ymm5,YMMWORD PTR [rax+rdi*8+0x20]  # load 4 items
59eb:│  vmovupd      ymm6,YMMWORD PTR [rax+rdi*8+0x40]  # load 4 items
59f1:│  vmovupd      ymm7,YMMWORD PTR [rax+rdi*8+0x60]  # load 4 items
59f7:│  vmulpd       ymm4,ymm4,YMMWORD PTR [rbx+rdi*8]
59fc:│  vaddpd       ymm0,ymm4,ymm0
5a00:│  vmulpd       ymm4,ymm5,YMMWORD PTR [rbx+rdi*8+0x20]
5a06:│  vaddpd       ymm1,ymm4,ymm1
5a0a:│  vmulpd       ymm4,ymm6,YMMWORD PTR [rbx+rdi*8+0x40]
5a10:│  vmulpd       ymm5,ymm7,YMMWORD PTR [rbx+rdi*8+0x60]
5a16:│  vaddpd       ymm2,ymm4,ymm2
5a1a:│  vaddpd       ymm3,ymm5,ymm3
5a1e:│  add          rdi,0x10
5a22:├──cmp          rsi,rdi
5a25:└──jne          59e0

代码不是最优的(因为 vaddpd 指令的延迟,它应该至少展开循环 6 次),但它非常好。

【讨论】:

  • “幕后”部分很棒。我对 ASM 知之甚少,因此学习总是很愉快。让我想忽略“编译器比你更好”的事情并尝试编写 ASM。
猜你喜欢
  • 2020-02-25
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2014-11-12
  • 2019-01-17
  • 2016-03-14
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多