【发布时间】:2017-12-15 00:37:25
【问题描述】:
我正在努力使用 Cython 提高我的 python 粒子跟踪代码的性能。
这是我的纯 Python 代码:
from scipy.integrate import odeint
import numpy as np
from numpy import sqrt, pi, sin, cos
from time import time as Time
import multiprocessing as mp
from functools import partial
cLight = 299792458.
Dim = 6
class Integrator:
def __init__(self, ring):
self.ring = ring
def equations(self, X, s):
dXds = np.zeros(Dim)
E, B = self.ring.getEMField( [X[0], X[2], s], X[4] )
h = 1 + X[0]/self.ring.ringRadius
p_s = np.sqrt(X[5]**2 - self.ring.particle.mass**2 - X[1]**2 - X[3]**2)
dtds = h*X[5]/p_s
gamma = X[5]/self.ring.particle.mass
beta = np.array( [X[1], X[3], p_s] ) / X[5]
dXds[0] = dtds*beta[0]
dXds[2] = dtds*beta[1]
dXds[1] = p_s/self.ring.ringRadius + self.ring.particle.charge*(dtds*E[0] + dXds[2]*B[2] - h*B[1])
dXds[3] = self.ring.particle.charge*(dtds*E[1] + h*B[0] - dXds[0]*B[2])
dXds[4] = dtds
dXds[5] = self.ring.particle.charge*(dXds[0]*E[0] + dXds[2]*E[1] + h*E[2])
return dXds
def odeSolve(self, X0, sRange):
sol = odeint(self.equations, X0, sRange)
return sol
class Ring:
def __init__(self, particle):
self.particle = particle
self.ringRadius = 7.112
self.magicB0 = self.particle.magicMomentum/self.ringRadius
def getEMField(self, pos, time):
x, y, s = pos
theta = (s/self.ringRadius*180/pi) % 360
r = sqrt(x**2 + y**2)
arg = 0 if r == 0 else np.angle( complex(x/r, y/r) )
rn = r/0.045
k2 = 37*24e3
k10 = -4*24e3
E = np.zeros(3)
B = np.array( [ 0, self.magicB0, 0 ] )
for i in range(4):
if ((21.9+90*i < theta < 34.9+90*i or 38.9+90*i < theta < 64.9+90*i) and (-0.05 < x < 0.05 and -0.05 < y < 0.05)):
E = np.array( [ k2*x/0.045 + k10*rn**9*cos(9*arg), -k2*y/0.045 -k10*rn**9*sin(9*arg), 0] )
break
return E, B
class Particle:
def __init__(self):
self.mass = 105.65837e6
self.charge = 1.
self.gm2 = 0.001165921
self.magicMomentum = self.mass/sqrt(self.gm2)
self.magicEnergy = sqrt(self.magicMomentum**2 + self.mass**2)
self.magicGamma = self.magicEnergy/self.mass
self.magicBeta = self.magicMomentum/(self.magicGamma*self.mass)
def runSimulation(nParticles, tEnd):
particle = Particle()
ring = Ring(particle)
integrator = Integrator(ring)
Xs = np.array( [ np.array( [45e-3*(np.random.rand()-0.5)*2, 0, 0, 0, 0, particle.magicEnergy] ) for i in range(nParticles) ] )
sRange = np.arange(0, tEnd, 1e-9)*particle.magicBeta*cLight
ode = partial(integrator.odeSolve, sRange=sRange)
t1 = Time()
pool = mp.Pool()
sol = np.array(pool.map(ode, Xs))
t2 = Time()
print ("%.3f sec" %(t2-t1))
return t2-t1
显然,最耗时的过程是积分 ODE,在 Integrator 类中定义为 odeSolve() 和 equations()。此外,在求解过程中,类 Ring 中的 getEMField() 方法与 equations() 方法一样多。 我尝试使用 Cython 获得显着的加速(至少 10 倍~20 倍),但通过以下 Cython 脚本我只获得了约 1.5 倍的加速水平:
import cython
import numpy as np
cimport numpy as np
from libc.math cimport sqrt, pi, sin, cos
from scipy.integrate import odeint
from time import time as Time
import multiprocessing as mp
from functools import partial
cdef double cLight = 299792458.
cdef int Dim = 6
@cython.boundscheck(False)
cdef class Integrator:
cdef Ring ring
def __init__(self, ring):
self.ring = ring
cpdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] equations(self,
np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] X,
double s):
cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] dXds = np.zeros(Dim)
cdef double h, p_s, dtds, gamma
cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] beta, E, B
E, B = self.ring.getEMField( [X[0], X[2], s], X[4] )
h = 1 + X[0]/self.ring.ringRadius
p_s = np.sqrt(X[5]*X[5] - self.ring.particle.mass*self.ring.particle.mass - X[1]*X[1] - X[3]*X[3])
dtds = h*X[5]/p_s
gamma = X[5]/self.ring.particle.mass
beta = np.array( [X[1], X[3], p_s] ) / X[5]
dXds[0] = dtds*beta[0]
dXds[2] = dtds*beta[1]
dXds[1] = p_s/self.ring.ringRadius + self.ring.particle.charge*(dtds*E[0] + dXds[2]*B[2] - h*B[1])
dXds[3] = self.ring.particle.charge*(dtds*E[1] + h*B[0] - dXds[0]*B[2])
dXds[4] = dtds
dXds[5] = self.ring.particle.charge*(dXds[0]*E[0] + dXds[2]*E[1] + h*E[2])
return dXds
cpdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] odeSolve(self,
np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] X0,
np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] sRange):
sol = odeint(self.equations, X0, sRange)
return sol
@cython.boundscheck(False)
cdef class Ring:
cdef Particle particle
cdef double ringRadius
cdef double magicB0
def __init__(self, particle):
self.particle = particle
self.ringRadius = 7.112
self.magicB0 = self.particle.magicMomentum/self.ringRadius
cpdef tuple getEMField(self,
list pos,
double time):
cdef double x, y, s
cdef double theta, r, rn, arg, k2, k10
cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] E, B
x, y, s = pos
theta = (s/self.ringRadius*180/pi) % 360
r = sqrt(x*x + y*y)
arg = 0 if r == 0 else np.angle( complex(x/r, y/r) )
rn = r/0.045
k2 = 37*24e3
k10 = -4*24e3
E = np.zeros(3)
B = np.array( [ 0, self.magicB0, 0 ] )
for i in range(4):
if ((21.9+90*i < theta < 34.9+90*i or 38.9+90*i < theta < 64.9+90*i) and (-0.05 < x < 0.05 and -0.05 < y < 0.05)):
E = np.array( [ k2*x/0.045 + k10*rn**9*cos(9*arg), -k2*y/0.045 -k10*rn**9*sin(9*arg), 0] )
#E = np.array( [ k2*x/0.045, -k2*y/0.045, 0] )
break
return E, B
cdef class Particle:
cdef double mass
cdef double charge
cdef double gm2
cdef double magicMomentum
cdef double magicEnergy
cdef double magicGamma
cdef double magicBeta
def __init__(self):
self.mass = 105.65837e6
self.charge = 1.
self.gm2 = 0.001165921
self.magicMomentum = self.mass/sqrt(self.gm2)
self.magicEnergy = sqrt(self.magicMomentum**2 + self.mass**2)
self.magicGamma = self.magicEnergy/self.mass
self.magicBeta = self.magicMomentum/(self.magicGamma*self.mass)
def runSimulation(nParticles, tEnd):
particle = Particle()
ring = Ring(particle)
integrator = Integrator(ring)
#nParticles = 5
Xs = np.array( [ np.array( [45e-3*(np.random.rand()-0.5)*2, 0, 0, 0, 0, particle.magicEnergy] ) for i in range(nParticles) ] )
sRange = np.arange(0, tEnd, 1e-9)*particle.magicBeta*cLight
ode = partial(integrator.odeSolve, sRange=sRange)
t1 = Time()
pool = mp.Pool()
sol = np.array(pool.map(ode, Xs))
t2 = Time()
print ("%.3f sec" %(t2-t1))
return t2-t1
我应该怎么做才能让 Cython 发挥最大的作用? (我尝试使用 Numba 代替 Cython,实际上 Numba 的性能提升是巨大的(大约 20 倍加速)。但我很难将 Numba 与 python 类实例一起使用,因此我决定使用 Cython 而不是 Numba)。
【问题讨论】:
-
您是否对代码进行了基准测试以找到瓶颈?通过快速阅读,对我来说,Cython 或 Numba 是否能够提供很大的加速并不是很明显:您的大部分操作已经以矢量化方式完成。我会先使用line profiler 找出慢点在哪里。
-
@jakevdp 感谢您的评论。我查找使用 line profiler,但似乎我首先需要学习如何在 Cython 和 Python3 上使用它......这需要一些时间。如果有帮助,我添加了带有注释模式的 Cython 编译结果。
-
我认为他建议在您的原始/非 cython 代码上使用 line profiler 来查看哪些操作很慢。如果这些是基本的 numpy-primitives / 矢量化部分,你知道 cython 不会有帮助。
-
您可能会一直在 odeint 中使用?最好的选择是进行多处理?我想剩下的会很快?
-
是的,我建议在将 Python 代码转换为 Cython 之前对其进行分析。我怀疑瓶颈出现在 cython 和 numba 帮助有限的地方。
标签: python numpy scipy cython cythonize