【问题标题】:Can I do better on filtering numpy array我可以在过滤 numpy 数组方面做得更好吗
【发布时间】:2021-03-18 11:17:23
【问题描述】:

我有一个有点做作的例子来进行 cytonize,我想要一个函数来:

  1. 接受任意长度的一维 numpy 数组(~100'000 ÷ 1'000'000 np.float64's)
  2. 对其进行一些过滤
  3. 以相同长度的新 [numpy?] 数组形式返回结果

代码和分析如下:

%%cython -a

from libc.stdlib cimport malloc, free
from cython cimport boundscheck, wraparound
import numpy as np

@boundscheck(False)
@wraparound(False)
def func_memview(double[:] arr):
    cdef:
        int N = arr.shape[0], i
        double *out_ptr = <double *> malloc(N * sizeof(double))
        double[:] out = <double[:N]>out_ptr
    for i in range(1, N):
        if arr[i] > arr[i-1]:
            out[i] = arr[i]
        else:
            out[i] = 0.
    free(out_ptr)
    return np.asarray(out)

我的问题是我能做得更好吗?

【问题讨论】:

  • free(out_ptr); return np.asarray(out) 绝对可以做得更好!就我个人而言,我只会在函数的开头使用np.empty((N,)),并避免使用malloc
  • @DavidW 在顶部做out = np.empty(N,) 让我的时间差了 15 倍...可以作为答案的一个例子吗?
  • 您当前代码的问题在于outout_ptr 的内存视图。如果你删除out_ptr,那么你也不能使用out。在短期内你可能会侥幸逃脱,但最终记忆会被覆盖。我认为(但我不是 100% 确定)np.asarray(out) 没有制作另一个副本,但始终是相同(无效)内存的另一个视图。 out = np.empty(N,) 可能会更慢,但它确实有效!我现在没有一个简单的方法来测试一个完整的答案,但是如果没有其他人先给出一个,我稍后会写一个。
  • @DavidW 如果你能展示一些可以改进时间的代码,我真的很感激。
  • 还有一个小问题,如果你坚持使用c指针版本,你需要检查malloc的结果以确保它不为NULL,如果内存分配失败就会发生这种情况(例如,您的 RAM 碎片化或没有足够的空间来分配数组)。

标签: python numpy cython


【解决方案1】:

正如 DavidW 所指出的,您的代码在内存管理方面存在一些问题,最好直接使用 numpy-array:

%%cython

from cython cimport boundscheck, wraparound
import numpy as np

@boundscheck(False)
@wraparound(False)
def func_memview_correct(double[:] arr):
    cdef:
        int N = arr.shape[0], i
        double[:] out = np.empty(N)
    for i in range(1, N):
        if arr[i] > arr[i-1]:
            out[i] = arr[i]
        else:
            out[i] = 0.0
    return np.asarray(out)

它的速度与错误的原始版本差不多:

import numpy as np
np.random.seed(0)
k= np.random.rand(5*10**7)

%timeit func_memview(k)          # 413 ms ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit func_memview_correct(k)  # 412 ms ± 15.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

问题是如何使这段代码更快?最明显的选择是

  1. 并行化。
  2. 使用矢量化/SIMD 指令。

众所周知,很难确保 Cython 生成的 C 代码被矢量化,例如,请参见 SO-post。对于许多编译器来说,有必要使用连续内存视图来改善这种情况,即:

%%cython -c=/O3

from cython cimport boundscheck, wraparound
import numpy as np

@boundscheck(False)
@wraparound(False)
def func_memview_correct_cont(double[::1] arr):  // <---- HERE
    cdef:
        int N = arr.shape[0], i
        double[::1] out = np.empty(N)   // <--- HERE
    for i in range(1, N):
        if arr[i] > arr[i-1]:
            out[i] = arr[i]
        else:
            out[i] = 0.0
    return np.asarray(out)

在我的机器上并没有快多少

%timeit func_memview_correct_cont(k)  # 402 ms ± 11.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

其他编译器可能会做得更好。但是,我经常看到 gcc 和 msvc 努力为典型的过滤代码生成最佳汇编程序(例如,参见 SO-question)。 Clang 在这方面要好得多,所以最简单的解决方案可能是使用numba

import numba as nb

@nb.njit
def nb_func(arr):
    N = arr.shape[0]
    out = np.empty(N)
    for i in range(1, N):
        if arr[i] > arr[i-1]:
            out[i] = arr[i]
        else:
            out[i] = 0.0
    return out

它的性能几乎是 cython 代码的 3 倍:

%timeit nb_func(k)  # 151 ms ± 2.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

使用prange 很容易并行化 numba 版本,但胜利并不多:并行化版本在我的机器上运行时间为 116 毫秒。


总结一下:对于此类任务,我的建议是使用 numba。使用 cython 比较棘手,最终性能将取决于后台使用的编译器。

【讨论】:

  • 感谢您的回答!您能否澄清一下原始实现有哪些问题以及为什么称其为错误?在我看来,我从 Smith,2015 年的“Cython:Python 程序员指南”中逐字复制了动态分配的数组,从概念上讲,我似乎都使用缓冲区,手动(我的)和 numpy(你的)管理内存。我的是“故障”。为什么?
  • @SergeyBushmanov 不确定我能不能比stackoverflow.com/questions/66690006/… 更好地释放out_ptr,它的所有视图(out,返回的numpy-array)都变得无效。你能做的就是这个stackoverflow.com/a/60856020/5769463
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2011-10-13
  • 2011-09-20
  • 1970-01-01
  • 1970-01-01
  • 2013-02-23
相关资源
最近更新 更多