【发布时间】:2019-09-22 23:48:13
【问题描述】:
我正在尝试加快 3d 数组中沿 Z 轴的平均值的计算。我阅读了 cython 的文档以添加类型、内存视图等,以完成此任务。但是,当我比较两者时:基于 numpy 的函数和基于 .so 文件中的 cython 语法和编译的函数,第一个胜过第二个。我的代码中是否有步骤或类型声明出错/丢失?
这是我的 numpy 版本:python_mean.py
import numpy as np
def mean_py(array):
x = array.shape[1]
y = array.shape[2]
values = []
for i in range(x):
for j in range(y):
values.append((np.mean(array[:, i, j])))
values = np.array([values])
values = values.reshape(500,500)
return values
这是我的 cython_mean.pyx 文件
%%cython
from cython import wraparound, boundscheck
import numpy as np
cimport numpy as np
DTYPE = np.double
@boundscheck(False)
@wraparound(False)
def cy_mean(double[:,:,:] array):
cdef Py_ssize_t x_max = array.shape[1]
cdef Py_ssize_t y_max = array.shape[2]
cdef double[:,:] result = np.zeros([x_max, y_max], dtype = DTYPE)
cdef double[:,:] result_view = result
cdef Py_ssize_t i,j
cdef double mean
cdef list values
for i in range(x_max):
for j in range(y_max):
mean = np.mean(array[:,i,j])
result_view[i,j] = mean
return result
当我导入这两个函数并开始对 3D numpy 数组进行计算时,我得到以下结果:
import numpy as np
a = np.random.randn(250_000)
b = np.random.randn(250_000)
c = np.random.randn(250_000)
array = np.vstack((a,b,c)).reshape(3, 500, 500)
import mean_py
from mean_py import mean_py
%timeit mean_py(array)
4.82 s ± 84.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
import cython_mean
from cython_mean import cy_mean
7.3 s ± 499 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
为什么 cython 代码的性能如此之低? 感谢您的帮助
【问题讨论】:
-
Cython 是优化 Python 代码的通用解决方案。 NumPy 是数学计算的特定解决方案。所以,对于数学计算,NumPy 在大多数情况下应该会胜出……
-
总是使用 %%cython -a 看看到底发生了什么。问题是使用 np.mean()。如果你在循环中写出 np.mean,你可以很容易地达到 Numpy 的性能(很可能 numpy 的实现也是用 Cython 编写的)。
标签: python arrays numpy cython mean