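【Posted】: 2014-10-06 18:41:48
【Question】:
I have a piece of code that I am trying to optimize. Most of the execution time is taken up by cdef np.ndarray index = np.argwhere(array==1),
where array is a 512x512x512 numpy array made up of zeros and ones. Any ideas on how to speed this up? Using Python 2.7, Numpy 1.8.1.
Sphericity function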
def sphericity(self,array):
    #Pass a mask array (1's are marked, 0's ignored)
    cdef np.ndarray index = np.argwhere(array==1)
    cdef int xSize,ySize,zSize
    xSize,ySize,zSize=array.shape
    cdef int sa,vol,voxelIndex,x,y,z,neighbors,xDiff,yDiff,zDiff,x1,y1,z1
    cdef float onethird,twothirds,sph
    sa=vol=0 #keep running tally of volume and surface area
    #cdef int nonZeroCount = (array != 0).sum() #Replaces np.count_nonzero(array) for speed
    for voxelIndex in range(np.count_nonzero(array)):
    #for voxelIndex in range(nonZeroCount):
        x=index[voxelIndex,0]
        y=index[voxelIndex,1]
        z=index[voxelIndex,2]
        #print x,y,z,array[x,y,z]
        neighbors=0
        vol+=1
        #Visit only the 6 face-adjacent neighbors (exactly one offset is non-zero)
        for xDiff in [-1,0,1]:
            for yDiff in [-1,0,1]:
                for zDiff in [-1,0,1]:
                    if abs(xDiff)+abs(yDiff)+abs(zDiff)==1:
                        x1=x+xDiff
                        y1=y+yDiff
                        z1=z+zDiff
                        if x1>=0 and y1>=0 and z1>=0 and x1<xSize and y1<ySize and z1<zSize:
                            #print '-',x1,y1,z1,array[x1,y1,z1]
                            if array[x1,y1,z1]:
                                #print '-',x1,y1,z1,array[x1,y1,z1]
                                neighbors+=1
        #print 'had this many neighbors',neighbors
        sa+=(6-neighbors) #each face without a marked neighbor is exposed surface
    onethird=float(1)/float(3)
    twothirds=float(2)/float(3)
    sph = ((np.pi**onethird)*((6*vol)**twothirds)) / sa
    #print 'sphericity',sph
    return sph
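For comparison, here is a minimal sketch of the same calculation in plain NumPy that never builds the index array at all. It is not the code from the question: the name sphericity_vectorized is hypothetical, and it assumes every non-zero voxel counts as marked. It relies on the identity sa = 6*vol - 2*shared, where shared is the number of faces shared by two marked voxels, which matches the per-voxel 6-neighbors tally above (voxels on the array border are treated as exposed in both versions).

```python
import numpy as np


def sphericity_vectorized(array):
    """Sphericity of a 3-D 0/1 mask, computed with array slicing only.

    Hypothetical helper for illustration; assumes non-zero means "marked".
    """
    mask = np.asarray(array) != 0
    if mask.ndim != 3:
        raise ValueError("expected a 3-D mask")

    vol = int(np.count_nonzero(mask))
    if vol == 0:
        raise ValueError("mask contains no marked voxels")

    # Count faces shared by two marked voxels along each axis by
    # comparing the mask with itself shifted by one voxel.
    shared = 0
    shared += int(np.count_nonzero(mask[1:, :, :] & mask[:-1, :, :]))
    shared += int(np.count_nonzero(mask[:, 1:, :] & mask[:, :-1, :]))
    shared += int(np.count_nonzero(mask[:, :, 1:] & mask[:, :, :-1]))

    # Every shared face hides one face on each of the two voxels.
    sa = 6 * vol - 2 * shared

    return (np.pi ** (1.0 / 3.0)) * ((6.0 * vol) ** (2.0 / 3.0)) / sa


if __name__ == "__main__":
    # Small stand-in for the 512x512x512 test array used in the question.
    demo = np.zeros((32, 32, 32))
    demo[5:15, 5:15, 5:15] = 1
    print(sphericity_vectorized(demo))
```

Whether this is faster than the Cython loop on the real 512x512x512 data would need to be measured. The sketch only shows that the surface-area tally does not strictly require np.argwhere.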
Profiling test
#Imports
import pstats, cProfile
import numpy as np
import pyximport
pyximport.install(setup_args={"script_args":["--compiler=mingw32"], "include_dirs":np.get_include()}, reload_support=True) #Generate cython version
#Create fake array to calc sphericity
fakeArray=np.zeros((512,512,512))
fakeArray[200:300,200:300,200:300]=1
#Profiling stuff
cProfile.runctx("sphericity(fakeArray)", globals(), locals(), "Profile.prof")
s = pstats.Stats("Profile.prof")
s.strip_dirs().sort_stats("time").print_stats()
Profiler output
Mon Oct 06 11:49:57 2014    Profile.prof

         12 function calls in 4.373 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    3.045    3.045    4.373    4.373 <string>:1(<module>)
        1    1.025    1.025    1.025    1.025 {method 'nonzero' of 'numpy.ndarray' objects}
        2    0.302    0.151    0.302    0.151 {numpy.core.multiarray.array}
        1    0.001    0.001    1.328    1.328 numeric.py:731(argwhere)
        1    0.000    0.000    0.302    0.302 fromnumeric.py:492(transpose)
        1    0.000    0.000    0.302    0.302 fromnumeric.py:38(_wrapit)
        1    0.000    0.000    0.000    0.000 {method 'transpose' of 'numpy.ndarray' objects}
        1    0.000    0.000    0.302    0.302 numeric.py:392(asarray)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        1    0.000    0.000    0.000    0.000 numeric.py:462(asanyarray)
        1    0.000    0.000    0.000    0.000 {getattr}
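In the profile, the 1.328 s spent in argwhere splits into roughly 1.025 s in nonzero and 0.302 s in the array conversion done for the transpose. The sketch below is an illustration on a tiny stand-in array, not a measured fix. It shows that np.nonzero returns the same coordinates as np.argwhere, but as one index array per axis, so the transposed (N, 3) copy is only needed if the calling code insists on that layout.

```python
import numpy as np

# Tiny stand-in for the mask from the question (assumption: 0/1 values).
mask = np.zeros((6, 6, 6))
mask[2:4, 2:4, 2:4] = 1

# What the question uses: an (N, 3) array of coordinates.
coords = np.argwhere(mask == 1)

# The same coordinates as three 1-D arrays, one per axis.
xs, ys, zs = np.nonzero(mask == 1)

# Stacking the per-axis arrays reproduces the argwhere result.
stacked = np.column_stack((xs, ys, zs))
print(np.array_equal(coords, stacked))
```

In the loop from the question, x=index[voxelIndex,0] would then become x=xs[voxelIndex], and likewise for y and z. How much time that saves on the full-size array is untested here.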
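【Comments】:
- Please state your development environment (operating system, Python version, Numpy version, Cython version).
- Could you provide the bit of code that initializes the "array" variable? One last question: how do you know this is your bottleneck? What have you tried?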