【问题标题】:Python / Cython / Numpy optimization of np.nonzeronp.nonzero 的 Python / Cython / Numpy 优化
【发布时间】:2014-10-06 18:41:48
【问题描述】:

我有一段代码正在尝试优化。大部分代码执行时间由cdef np.ndarray index = np.argwhere(array==1) 占用 其中 array 是一个 numpy 是一个 512x512,512 由零和一组成的 numpy 数组。有什么想法可以加快速度吗?使用 Python 2.7、Numpy 1.8.1

球形函数

def sphericity(self,array):

    #Pass an mask array (1's are marked, 0's ignored)
    cdef np.ndarray index = np.argwhere(array==1)
    cdef int xSize,ySize,zSize
    xSize,ySize,zSize=array.shape

    cdef int sa,vol,voxelIndex,x,y,z,neighbors,xDiff,yDiff,zDiff,x1,y1,z1
    cdef float onethird,twothirds,sp
    sa=vol=0 #keep running tally of volume and surface area
    #cdef int nonZeroCount = (array != 0).sum() #Replaces np.count_nonzero(array) for speed
    for voxelIndex in range(np.count_nonzero(array)):
    #for voxelIndex in range(nonZeroCount):
        x=index[voxelIndex,0]
        y=index[voxelIndex,1]
        z=index[voxelIndex,2]
        #print x,y,z,array[x,y,z]
        neighbors=0
        vol+=1

        for xDiff in [-1,0,1]:
            for yDiff in [-1,0,1]:
                for zDiff in [-1,0,1]:
                    if abs(xDiff)+abs(yDiff)+abs(zDiff)==1:
                        x1=x+xDiff
                        y1=y+yDiff
                        z1=z+zDiff
                        if x1>=0 and y1>=0 and z1>=0 and x1<xSize and y1<ySize and z1<zSize:
                            #print '-',x1,y1,z1,array[x1,y1,z1]
                            if array[x1,y1,z1]:
                                #print '-',x1,y1,z1,array[x1,y1,z1]
                                neighbors+=1

        #print 'had this many neighbors',neighbors
        sa+=(6-neighbors)

    onethird=float(1)/float(3)
    twothirds=float(2)/float(3)
    sph = ((np.pi**onethird)*((6*vol)**twothirds)) / sa
    #print 'sphericity',sphericity
    return sph

分析测试

#Imports
import pstats, cProfile
import numpy as np
import pyximport
pyximport.install(setup_args={"script_args":["--compiler=mingw32"], "include_dirs":np.get_include()}, reload_support=True) #Generate cython version

#Create fake array to calc sphericity
fakeArray=np.zeros((512,512,512))
fakeArray[200:300,200:300,200:300]=1

#Profiling stuff
cProfile.runctx("sphericity(fakeArray)", globals(), locals(), "Profile.prof")
s = pstats.Stats("Profile.prof")
s.strip_dirs().sort_stats("time").print_stats()

分析的输出

Mon Oct 06 11:49:57 2014    Profile.prof

         12 function calls in 4.373 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    3.045    3.045    4.373    4.373 <string>:1(<module>)
        1    1.025    1.025    1.025    1.025 {method 'nonzero' of 'numpy.ndarray' objects}
        2    0.302    0.151    0.302    0.151 {numpy.core.multiarray.array}
        1    0.001    0.001    1.328    1.328 numeric.py:731(argwhere)
        1    0.000    0.000    0.302    0.302 fromnumeric.py:492(transpose)
        1    0.000    0.000    0.302    0.302 fromnumeric.py:38(_wrapit)
        1    0.000    0.000    0.000    0.000 {method 'transpose' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.302    0.302 numeric.py:392(asarray)
        1    0.000    0.000    0.000    0.000 numeric.py:462(asanyarray)
        1    0.000    0.000    0.000    0.000 {getattr}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

【问题讨论】:

  • 请说明您的开发环境(操作系统、Python版本、Numpy版本、Cython版本)
  • 您能否提供初始化“数组”变量的代码位?最后一个问题:你怎么知道这是你的瓶颈?你尝试了什么?

标签: python numpy cython


【解决方案1】:

Jaime 可能给出了一个很好的答案,但我会评论改进 Cython 代码并添加性能比较。

首先,您应该使用“注释”功能,cython -a filename.pyx,这将生成一个 HTML 文件。在浏览器中加载它,它会用黄橙色突出显示“慢”行,这表明可以改进的地方。

Annotate 立即揭示了两个很容易修复的东西:

将习语转换为 cython 可以理解的内容

首先,这些行很慢:

        for xDiff in [-1,0,1]:
            for yDiff in [-1,0,1]:
                for zDiff in [-1,0,1]:

原因是 Cython 不知道如何将列表迭代转换为干净的 c 代码。需要将其转化为 Cython 可以优化的等效代码,即“范围内”形式:

        for xDiff in range(-1, 2):
            for yDiff in range(-1, 2):
                for zDiff in range(-1, 2):

为快速索引键入数组

接下来就是这条线很慢:

                            if array[x1,y1,z1]:

原因是array 没有被赋予类型。因此,它使用 python 级索引而不是 c 级索引。为了解决这个问题,你需要给数组一个类型,这可以通过这种方式完成:

def sphericity(np.ndarray[np.uint8_t, ndim=3] array):

假设数组的类型是'uint8',替换成合适的类型(注意:Cython不支持'np.bool'类型,所以我使用'uint8')

你也可以使用内存视图,你不能在内存视图上使用 numpy 函数,但是你可以在数组上创建一个视图,然后索引视图而不是数组:

    cdef np.uint8_t array_view [:, :, :] = array
    ...
                                    if array_view[x1,y1,z1]:

内存视图可能会稍微快一些,并且在数组(python 级别调用)和视图(c 级别调用)之间进行了清晰的划分。如果你不使用 numpy 函数,你可以使用内存视图没有问题。

重写代码以避免多次遍历数组

剩下的是计算 indexnonZeroCount 都很慢,这是出于各种原因,但主要与数据的绝对大小有关(本质上,迭代 512*512*512 元素只需要时间!) 一般来说,Numpy 可以做的任何事情,优化的 Cython 都可以做得更快(通常快 2-10 倍) - numpy 只会为您节省大量重新发明轮子和大量打字的时间,并让您在更高层次上思考(如果您不是也是一个 c 程序员,你可能无法很好地优化 cython)。但在这种情况下很容易,您可以删除 indexnonZeroCount 以及所有相关代码,然后执行以下操作:

    for x in range(0, xSize):
        for y in range(0, ySize):
            for z in range(0, zSize):
                if array[x,y,z] == 0:
                    continue
                ... 

这是非常快的,因为 c(干净的 Cython 可以完美地编译)每秒执行数十亿次操作没有问题。通过消除indexnonZeroCount 步骤,您基本上可以节省整个阵列上的两个完整迭代,即使在最大速度下,每个迭代也至少需要大约0.1 秒。更重要的是 CPU 缓存,整个阵列是 128mb,比一个 cpu 缓存大得多,因此一次完成所有操作可以更好地利用 cpu 缓存(如果阵列完全适合 CPU,那么多次通过就没有那么重要了缓存)。

优化版

这是我优化版本的完整代码:

#cython: boundscheck=False, nonecheck=False, wraparound=False
import numpy as np
cimport numpy as np

def sphericity2(np.uint8_t [:, :, :] array):

    #Pass an mask array (1's are marked, 0's ignored)
    cdef int xSize,ySize,zSize
    xSize=array.shape[0]
    ySize=array.shape[1]
    zSize=array.shape[2]

    cdef int sa,vol,x,y,z,neighbors,xDiff,yDiff,zDiff,x1,y1,z1
    cdef float onethird,twothirds,sp
    sa=vol=0 #keep running tally of volume and surface area

    for x in range(0, xSize):
        for y in range(0, ySize):
            for z in range(0, zSize):
                if array[x,y,z] == 0:
                    continue

                neighbors=0
                vol+=1

                for xDiff in range(-1, 2):
                    for yDiff in range(-1, 2):
                        for zDiff in range(-1, 2):
                            if abs(xDiff)+abs(yDiff)+abs(zDiff)==1:
                                x1=x+xDiff
                                y1=y+yDiff
                                z1=z+zDiff
                                if x1>=0 and y1>=0 and z1>=0 and x1<xSize and y1<ySize and z1<zSize:
                                    #print '-',x1,y1,z1,array[x1,y1,z1]
                                    if array[x1,y1,z1]:
                                        #print '-',x1,y1,z1,array[x1,y1,z1]
                                        neighbors+=1

                #print 'had this many neighbors',neighbors
                sa+=(6-neighbors)

    onethird=float(1)/float(3)
    twothirds=float(2)/float(3)
    sph = ((np.pi**onethird)*((6*vol)**twothirds)) / sa
    #print 'sphericity',sphericity
    return sph

球形执行时间对比:

原始:2.123s 詹姆的 : 1.819s 优化 Cython : 0.136s @moarningsun:0.090s

在所有 Cython 解决方案中,运行速度提高了 15 倍以上,展开内循环(见评论)运行速度提高了 23 倍以上。

【讨论】:

  • 您可以通过对所有六个方向执行if x &gt; 0: neighbors += array[x-1,y,z] 之类的操作来轻松摆脱最内层的三个循环。它更快(我试过了),但你当然会失去一些普遍性。无论如何,很好的答案!
  • @moarningsun 我已经添加了消除内部循环和检查的时间
  • 好的,太好了!在我自己的时间里,我发现对于较小的阵列来说加速更为明显。对于更大的阵列,我猜内存带宽会成为瓶颈。
  • 问题:使用 def sphericity2(np.uint8_t [:, :, :] 数组):我收到此错误,ValueError: 不理解字符缓冲区 dtype 格式字符串 ('?')。我做错了什么?
  • @Dbricks 我认为这是因为 numpy 数组没有被赋予适当的类型,在我的测试代码中我使用了fakeArray=np.zeros((512,512,512), np.uint8)
【解决方案2】:

您可以从 vanilla numpy 获得大部分代码尝试执行的操作,而无需 Cython。核心是获得一种计算邻居的有效方法,这可以通过and-ing 从输入数组获得的掩码切片来完成。综上所述,我认为以下代码与您的代码相同,但重复次数要少得多:

def sphericity(arr):
    mask = arr != 0
    vol = np.count_nonzero(mask)
    counts = np.zeros_like(arr, dtype=np.intp)
    for dim, size in enumerate(arr.shape):
        slc = (slice(None),) * dim
        axis_mask = (mask[slc + (slice(None, -1),)] &
                     mask[slc + (slice(1, None),)])
        counts[slc + (slice(None, -1),)] += axis_mask
        counts[slc + (slice(1, None),)] += axis_mask
    sa = np.sum(6 - counts[counts != 0])

    return np.pi**(1./3.)*(6*vol)**(2./3.) / sa

【讨论】:

    猜你喜欢
    • 2011-07-16
    • 1970-01-01
    • 1970-01-01
    • 2021-04-19
    • 1970-01-01
    • 1970-01-01
    • 2016-11-02
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多