【问题标题】:Why is NumPy sometimes slower than NumPy + plain Python loop?为什么 NumPy 有时比 NumPy + 普通 Python 循环慢?
【发布时间】:2019-06-28 12:25:45
【问题描述】:

这是基于this question 2018-10 提出的。

考虑以下代码。三个简单的函数来计算 NumPy 3D 数组 (1000 × 1000 × 1000) 中的非零元素。

import numpy as np

def f_1(arr):
    return np.sum(arr > 0)

def f_2(arr):
    ans = 0
    for val in range(arr.shape[0]):
        ans += np.sum(arr[val, :, :] > 0)
    return ans

def f_3(arr):
    return np.count_nonzero(arr)

if __name__ == '__main__':

    data = np.random.randint(0, 10, (1_000, 1_000, 1_000))
    print(f_1(data))
    print(f_2(data))
    print(f_3(data))

我机器上的运行时(Python 3.7.?、Windows 10、NumPy 1.16.?):

%timeit f_1(data)
1.73 s ± 21.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit f_2(data)
1.4 s ± 1.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit f_3(data)
2.38 s ± 956 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

所以,f_2() 的运行速度比 f_1()f_3() 快。但是,较小的data 并非如此。问题是——为什么会这样?是 NumPy、Python 还是其他?

【问题讨论】:

  • 您正在声明一个 8Gb 阵列,它不适合许多台式机或笔记本电脑,因此您正在隐式测试缓存行为。 np.int64 类型的 3D numpy 数组(1000 × 1000 × 1000 立方体)是 10^9 个元素,每个元素 8 个字节。因此,这将归结为 8Gb 整数数组是否适合/被缓存/等在您的特定机器上。由于它不适合许多台式机或笔记本电脑,因此如果您不想测试缓存,请改用 100 × 100 × 100 的尺寸,即 8Mb。
  • 另外,为了确定这一点,请在np.random.randint 调用之前设置随机种子。
  • @Barker 您似乎忽略了秒和毫秒之间的差异。

标签: python performance numpy


【解决方案1】:

这是由于内存访问和缓存造成的。这些函数中的每一个都在做两件事,以第一个代码为例:

np.sum(arr > 0)

它首先进行比较以查找arr 大于零(或非零,因为arr 包含非负整数)的位置。这将创建一个与arr 形状相同的中间数组。然后,它对这个数组求和。

直截了当,对吧?好吧,当使用np.sum(arr > 0) 时,这是一个大数组。当它大到不适合缓存时,性能会下降,因为当处理器开始执行时,大多数数组元素将被从内存中逐出并需要重新加载。

由于f_2 迭代第一个维度,它正在处理更小的子数组。完成了相同的复制和求和,但这次中间数组适合内存。它在不留记忆的情况下被创建、使用和销毁。这要快得多。

现在,您可能认为f_3 会最快(使用内置方法等等),但查看source code 表明它使用以下操作:

a_bool = a.astype(np.bool_, copy=False)
return a_bool.sum(axis=axis, dtype=np.intp

a_bool 只是另一种查找非零条目的方法,它会创建一个大型中间数组。

结论

经验法则就是这样,而且经常是错误的。如果您想要更快的代码,请对其进行分析并查看问题所在(在此处进行了很好的工作)。

Python 在某些方面做得很好。在优化的情况下,它可以比numpy 更快。不要害怕将普通的旧 python 代码或数据类型与 numpy 结合使用。

如果您发现自己经常手动编写 for 循环以获得更好的性能,您可能需要查看 numexpr - 它会自动执行其中的一些操作。我自己并没有太多使用它,但如果中间数组会减慢您的程序速度,它应该会提供很好的加速。

【讨论】:

  • 好收获。事实上,通过删除 > 0 测试 numpy 是最快的
  • 感谢您的回复。你能解释一下缓存是什么意思吗?我问是因为data 大约是 3.7GB,data > 0 大约是 0.9GB,data[0, :, :] 大约是 3.8MB。
  • @Poolka,这是 CPU 缓存造成的。我不知道确切的数字,但通常它有几 MB 的内存(这是我用谷歌搜索的一个页面,其中有一些更好的数字:makeuseof.com/tag/what-is-cpu-cache)。
  • 对数组轴的依赖有什么想法@BlackBear 在他的回答中注意到了吗?运行时间增加了十倍,这让我很困惑。
  • @Poolka,Numpy 在内存布局方面做了正确的事情,并且问题中给出的每个选项都以相同的方式访问内存,所以这不是问题。
【解决方案2】:

这完全取决于数据在内存中的布局方式以及代码如何访问它。本质上,数据是以块的形式从内存中获取的,然后被缓存;如果算法设法使用缓存中的块中的数据,则无需再次从内存中读取。这可以节省大量时间,尤其是当缓存远小于您正在处理的数据时。

考虑这些变化,它们仅在我们迭代的轴上有所不同:

def f_2_0(arr):
    ans = 0
    for val in range(arr.shape[0]):
        ans += np.sum(arr[val, :, :] > 0)
    return ans

def f_2_1(arr):
    ans = 0
    for val in range(arr.shape[1]):
        ans += np.sum(arr[:, val, :] > 0)
    return ans

def f_2_2(arr):
    ans = 0
    for val in range(arr.shape[2]):
        ans += np.sum(arr[:, :, val] > 0)
    return ans

我的笔记本电脑上的结果:

%timeit f_1(data)
2.31 s ± 47.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit f_2_0(data)
1.88 s ± 60 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit f_2_1(data)
2.65 s ± 142 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit f_2_2(data)
12.8 s ± 650 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

你可以看到f_2_1 几乎和f_1 一样快,这让我觉得 numpy 没有使用最佳访问模式(f_2_0 使用的那个)other answer 中解释了缓存究竟如何影响时间。

【讨论】:

  • 感谢您的回答。我看到您编辑了答案并提到了另一个答案。但是,您注意到对数组轴的强烈依赖。将其命名为 numpy 问题是否正确?
  • @Poolka 这是一个 numpy 问题,因为data > 0 创建了一个copy,而不是一个view(与索引相反),但这不是迭代问题:如果您运行相同的测试并删除 > 0 部分,您将看到 numpy 最快(虽然不是很多,我的结果是 500ms、537ms、945ms 和 10.8s)
【解决方案3】:

让我们彻底删除临时数组

正如@user2699 在他的回答中已经提到的那样,分配和写入一个不适合缓存的大型数组会大大减慢这个过程。为了展示这种行为,我使用 Numba(JIT 编译器)编写了两个小函数。

在编译语言(C、Fortran、..)中,您通常会避免使用临时数组。在解释型 Python(不使用 Cython 或 Numba)中,您通常希望在更大的数据块(向量化)上调用编译函数,因为解释型代码中的循环非常慢。但这也有不利的一面(如临时数组、缓存使用不当)

无需临时数组分配的函数

@nb.njit(fastmath=True,parallel=False)
def f_4(arr):
    sum=0
    for i in nb.prange(arr.shape[0]):
        for j in range(arr.shape[1]):
            for k in range(arr.shape[2]):
                if arr[i,j,k]>0:
                    sum+=1
    return sum

带临时数组

请注意,如果开启并行化parallel=True,编译器不仅会尝试并行化代码,还会开启循环融合等其他优化。

@nb.njit(fastmath=True,parallel=False)
def f_5(arr):
    return np.sum(arr>0)

时间安排

%timeit f_1(data)
1.65 s ± 48.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_2(data)
1.27 s ± 5.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_3(data)
1.99 s ± 6.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit f_4(data) #parallel=false
216 ms ± 5.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_4(data) #parallel=true
121 ms ± 4.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_5(data) #parallel=False
1.12 s ± 19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_5(data) #parallel=true Temp-Array is automatically optimized away
146 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

【讨论】:

    猜你喜欢
    • 2017-11-09
    • 2020-12-31
    • 1970-01-01
    • 2021-01-16
    • 2017-12-17
    • 2019-03-17
    • 1970-01-01
    • 2016-05-03
    • 1970-01-01
    相关资源
    最近更新 更多