【问题标题】:Huge speed difference in numpy between similar code类似代码之间numpy的巨大速度差异
【发布时间】:2015-08-18 23:04:27
【问题描述】:

为什么下面的L2范数计算会有这么大的速度差异:

a = np.arange(1200.0).reshape((-1,3))

%timeit [np.sqrt((a*a).sum(axis=1))]
100000 loops, best of 3: 12 µs per loop

%timeit [np.sqrt(np.dot(x,x)) for x in a]
1000 loops, best of 3: 814 µs per loop

%timeit [np.linalg.norm(x) for x in a]
100 loops, best of 3: 2 ms per loop

据我所知,所有三个都产生相同的结果。

这里是 numpy.linalg.norm 函数的源代码:

x = asarray(x)

# Check the default case first and handle it immediately.
if ord is None and axis is None:
    x = x.ravel(order='K')
    if isComplexType(x.dtype.type):
        sqnorm = dot(x.real, x.real) + dot(x.imag, x.imag)
    else:
        sqnorm = dot(x, x)
    return sqrt(sqnorm)

编辑:有人建议可以并行化一个版本,但我检查了一下,事实并非如此。所有三个版本都消耗 12.5% 的 CPU(这通常是我的 4 个物理/8 个虚拟核 Xeon CPU 上的 Python 代码的情况。

【问题讨论】:

  • 还有几个时间:[math.sqrt(np.dot(x,x)) for x in a], np.sqrt(np.einsum('ij,ij->i',a,a))
  • 主要区别在于解释的 Python 代码和编译后的 C 代码所做的事情。
  • 我注意到的一件事是,第一种方法给出的结果比其他方法的精度要低得多。例如第一种方法产生的最终数字是2074.9973494,而后两种方法产生的最终数字是2074.9973493958973。
  • Tris Nefzger,我用 dtype 检查了结果,在所有三种情况下都是 float64。
  • 列表和数组有不同的规则来显示浮点数的有效数字。所以显示并没有告诉你很多浮点类型。

标签: python performance numpy


【解决方案1】:

np.dot 通常会调用 BLAS 库函数 - 因此它的速度将取决于您的 numpy 版本链接到的 BLAS 库。一般来说,我希望它具有更大的恒定开销,但随着数组大小的增加,可以更好地扩展。但是,您从列表解析中调用它(实际上是一个普通的 Python for 循环)这一事实可能会抵消使用 BLAS 的任何性能优势。

如果您摆脱列表理解并使用 axis= kwarg,np.linalg.norm 与您的第一个示例相当,但 np.einsum 比两者都快得多:

In [1]: %timeit np.sqrt((a*a).sum(axis=1))
The slowest run took 10.12 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 11.1 µs per loop

In [2]: %timeit np.linalg.norm(a, axis=1)
The slowest run took 14.63 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 13.5 µs per loop

# this is what np.linalg.norm does internally
In [3]: %timeit np.sqrt(np.add.reduce(a * a, axis=1))
The slowest run took 34.05 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 10.7 µs per loop

In [4]: %timeit np.sqrt(np.einsum('ij,ij->i',a,a))
The slowest run took 5.55 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 5.42 µs per loop

【讨论】:

  • 我很好奇,当我们用 axis=1 参数替换 Python 列表理解时,底层会发生什么变化?或者,为什么 Python 的“for 循环”比 C 的“for 循环”慢?
  • 回答您的第一个问题,当您使用矢量化(即axis=1 而不是列表 comp)时,numpy 本质上是在 C 代码级别而不是在 Python 中对数组元素进行循环。至于为什么在 C 中循环比 Python 快……这是一个很大的问题。简单的答案是,较慢的运行时性能是您为换取 Python 出色的高级语言功能(如类型检查、自动垃圾收集等)而付出的代价(请参阅this Q/A 了解更多详细信息)。
  • 矢量化是指英特尔 AVX 吗?关于 Python 与 C,我不太清楚:Python 字节码是否仍在进行所有检查/垃圾收集/等,还是仅在编译字节码期间完成一次?如果字节码没有这种开销,是什么让它比本机机器码慢?
  • 在 numpy-land 中,我们倾向于松散地使用“矢量化”来指代array programming。 Python 编译器生成的字节码与 C 等编译成的本机指令不同,而是一组用于在运行时进行类型检查、垃圾收集等的虚拟机指令。除了标准 CPython 之外,还有其他 Python 实现尝试各种运行时优化(例如,您可能听说过 PyPy,它使用优化的 JIT 编译器)。
  • 这一切都方式超出了我可以在 SE 问题的 cmets 中合理回答的范围。您可以找到有关 CPython 内部 herehere 的更多详细信息
猜你喜欢
  • 1970-01-01
  • 2014-07-06
  • 2019-12-28
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2012-12-07
  • 2014-06-20
  • 1970-01-01
相关资源
最近更新 更多