【问题标题】:Cython optimization of the code代码的 Cython 优化
【发布时间】:2017-12-15 00:37:25
【问题描述】:

我正在努力使用 Cython 提高我的 python 粒子跟踪代码的性能。

这是我的纯 Python 代码:

from scipy.integrate import odeint
import numpy as np
from numpy import sqrt, pi, sin, cos
from time import time as Time
import multiprocessing as mp
from functools import partial

cLight = 299792458.
Dim = 6

class Integrator:
    def __init__(self, ring):
        self.ring = ring

    def equations(self, X, s):
        dXds = np.zeros(Dim)

        E, B = self.ring.getEMField( [X[0], X[2], s], X[4] )

        h = 1 + X[0]/self.ring.ringRadius
        p_s = np.sqrt(X[5]**2 - self.ring.particle.mass**2 - X[1]**2 - X[3]**2)
        dtds = h*X[5]/p_s
        gamma = X[5]/self.ring.particle.mass
        beta = np.array( [X[1], X[3], p_s] ) / X[5]

        dXds[0] = dtds*beta[0]
        dXds[2] = dtds*beta[1]
        dXds[1] = p_s/self.ring.ringRadius + self.ring.particle.charge*(dtds*E[0] + dXds[2]*B[2] - h*B[1])
        dXds[3] = self.ring.particle.charge*(dtds*E[1] + h*B[0] - dXds[0]*B[2])
        dXds[4] = dtds
        dXds[5] = self.ring.particle.charge*(dXds[0]*E[0] + dXds[2]*E[1] + h*E[2])
        return dXds

    def odeSolve(self, X0, sRange):
        sol = odeint(self.equations, X0, sRange)
        return sol

class Ring:
    def __init__(self, particle):
        self.particle = particle
        self.ringRadius = 7.112
        self.magicB0 = self.particle.magicMomentum/self.ringRadius

    def getEMField(self, pos, time):
        x, y, s = pos
        theta = (s/self.ringRadius*180/pi) % 360
        r = sqrt(x**2 + y**2)
        arg = 0 if r == 0 else np.angle( complex(x/r, y/r) )
        rn = r/0.045

        k2 = 37*24e3
        k10 = -4*24e3

        E = np.zeros(3)
        B = np.array( [ 0, self.magicB0, 0 ] )

        for i in range(4):
            if ((21.9+90*i < theta < 34.9+90*i or 38.9+90*i < theta < 64.9+90*i) and (-0.05 < x < 0.05 and -0.05 < y < 0.05)):
                E = np.array( [ k2*x/0.045 + k10*rn**9*cos(9*arg), -k2*y/0.045 -k10*rn**9*sin(9*arg), 0] )
                break
        return E, B

class Particle:
    def __init__(self):
        self.mass = 105.65837e6
        self.charge = 1.
        self.gm2 = 0.001165921 

        self.magicMomentum = self.mass/sqrt(self.gm2)
        self.magicEnergy = sqrt(self.magicMomentum**2 + self.mass**2)
        self.magicGamma = self.magicEnergy/self.mass
        self.magicBeta = self.magicMomentum/(self.magicGamma*self.mass)


def runSimulation(nParticles, tEnd):
    particle = Particle()
    ring = Ring(particle)
    integrator = Integrator(ring)

    Xs = np.array( [ np.array( [45e-3*(np.random.rand()-0.5)*2, 0, 0, 0, 0, particle.magicEnergy] ) for i in range(nParticles) ] )
    sRange = np.arange(0, tEnd, 1e-9)*particle.magicBeta*cLight 

    ode = partial(integrator.odeSolve, sRange=sRange)

    t1 = Time()

    pool = mp.Pool()
    sol = np.array(pool.map(ode, Xs))

    t2 = Time()
    print ("%.3f sec" %(t2-t1))

    return t2-t1

显然,最耗时的过程是积分 ODE,在 Integrator 类中定义为 odeSolve() 和 equations()。此外,在求解过程中,类 Ring 中的 getEMField() 方法与 equations() 方法一样多。 我尝试使用 Cython 获得显着的加速(至少 10 倍~20 倍),但通过以下 Cython 脚本我只获得了约 1.5 倍的加速水平:

import cython
import numpy as np
cimport numpy as np
from libc.math cimport sqrt, pi, sin, cos

from scipy.integrate import odeint
from time import time as Time
import multiprocessing as mp
from functools import partial

cdef double cLight = 299792458.
cdef int Dim = 6

@cython.boundscheck(False)
cdef class Integrator:
    cdef Ring ring

    def __init__(self, ring):
        self.ring = ring

    cpdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] equations(self,
                  np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] X,
                  double s):
        cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] dXds = np.zeros(Dim)
        cdef double h, p_s, dtds, gamma
        cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] beta, E, B

        E, B = self.ring.getEMField( [X[0], X[2], s], X[4] )

        h = 1 + X[0]/self.ring.ringRadius
        p_s = np.sqrt(X[5]*X[5] - self.ring.particle.mass*self.ring.particle.mass - X[1]*X[1] - X[3]*X[3])
        dtds = h*X[5]/p_s
        gamma = X[5]/self.ring.particle.mass
        beta = np.array( [X[1], X[3], p_s] ) / X[5]

        dXds[0] = dtds*beta[0]
        dXds[2] = dtds*beta[1]
        dXds[1] = p_s/self.ring.ringRadius + self.ring.particle.charge*(dtds*E[0] + dXds[2]*B[2] - h*B[1])
        dXds[3] = self.ring.particle.charge*(dtds*E[1] + h*B[0] - dXds[0]*B[2])
        dXds[4] = dtds
        dXds[5] = self.ring.particle.charge*(dXds[0]*E[0] + dXds[2]*E[1] + h*E[2])
        return dXds

    cpdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] odeSolve(self,
                 np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] X0,
                 np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] sRange):
        sol = odeint(self.equations, X0, sRange)
        return sol

@cython.boundscheck(False)
cdef class Ring:
    cdef Particle particle
    cdef double ringRadius
    cdef double magicB0

    def __init__(self, particle):
        self.particle = particle
        self.ringRadius = 7.112
        self.magicB0 = self.particle.magicMomentum/self.ringRadius

    cpdef tuple getEMField(self,
                   list pos,
                   double time):
        cdef double x, y, s
        cdef double theta, r, rn, arg, k2, k10
        cdef np.ndarray[np.double_t, ndim=1, negative_indices=False, mode="c"] E, B

        x, y, s = pos
        theta = (s/self.ringRadius*180/pi) % 360
        r = sqrt(x*x + y*y)
        arg = 0 if r == 0 else np.angle( complex(x/r, y/r) )
        rn = r/0.045

        k2 = 37*24e3
        k10 = -4*24e3

        E = np.zeros(3)
        B = np.array( [ 0, self.magicB0, 0 ] )

        for i in range(4):
            if ((21.9+90*i < theta < 34.9+90*i or 38.9+90*i < theta < 64.9+90*i) and (-0.05 < x < 0.05 and -0.05 < y < 0.05)):
                E = np.array( [ k2*x/0.045 + k10*rn**9*cos(9*arg), -k2*y/0.045 -k10*rn**9*sin(9*arg), 0] )
                #E = np.array( [ k2*x/0.045, -k2*y/0.045, 0] )
                break
        return E, B

cdef class Particle:
    cdef double mass
    cdef double charge
    cdef double gm2

    cdef double magicMomentum
    cdef double magicEnergy
    cdef double magicGamma
    cdef double magicBeta

    def __init__(self):
        self.mass = 105.65837e6
        self.charge = 1.
        self.gm2 = 0.001165921 

        self.magicMomentum = self.mass/sqrt(self.gm2)
        self.magicEnergy = sqrt(self.magicMomentum**2 + self.mass**2)
        self.magicGamma = self.magicEnergy/self.mass
        self.magicBeta = self.magicMomentum/(self.magicGamma*self.mass)

def runSimulation(nParticles, tEnd):
    particle = Particle()
    ring = Ring(particle)
    integrator = Integrator(ring)

    #nParticles = 5
    Xs = np.array( [ np.array( [45e-3*(np.random.rand()-0.5)*2, 0, 0, 0, 0, particle.magicEnergy] ) for i in range(nParticles) ] )
    sRange = np.arange(0, tEnd, 1e-9)*particle.magicBeta*cLight 

    ode = partial(integrator.odeSolve, sRange=sRange)

    t1 = Time()

    pool = mp.Pool()
    sol = np.array(pool.map(ode, Xs))

    t2 = Time()
    print ("%.3f sec" %(t2-t1))

    return t2-t1

我应该怎么做才能让 Cython 发挥最大的作用? (我尝试使用 Numba 代替 Cython,实际上 Numba 的性能提升是巨大的(大约 20 倍加速)。但我很难将 Numba 与 python 类实例一起使用,因此我决定使用 Cython 而不是 Numba)。

供参考,以下是cython对其编译的注解:

【问题讨论】:

  • 您是否对代码进行了基准测试以找到瓶颈?通过快速阅读,对我来说,Cython 或 Numba 是否能够提供很大的加速并不是很明显:您的大部分操作已经以矢量化方式完成。我会先使用line profiler 找出慢点在哪里。
  • @jakevdp 感谢您的评论。我查找使用 line profiler,但似乎我首先需要学习如何在 Cython 和 Python3 上使用它......这需要一些时间。如果有帮助,我添加了带有注释模式的 Cython 编译结果。
  • 我认为他建议在您的原始/非 cython 代码上使用 line profiler 来查看哪些操作很慢。如果这些是基本的 numpy-primitives / 矢量化部分,你知道 cython 不会有帮助。
  • 您可能会一直在 odeint 中使用?最好的选择是进行多处理?我想剩下的会很快?
  • 是的,我建议在将 Python 代码转换为 Cython 之前对其进行分析。我怀疑瓶颈出现在 cython 和 numba 帮助有限的地方。

标签: python numpy scipy cython cythonize


【解决方案1】:

这是一个非常不完整的答案,因为我没有对任何内容进行分析或计时,甚至没有检查它是否给出了相同的答案。但是,这里有一些建议可以减少 Cython 生成的 Python 代码量:

  • 添加 @cython.cdivision(True) 编译指令。这意味着 ZeroDivisionError 不会在浮点除法上引发,而是您将获得 NaN 值。 (仅当您不想引发错误时才这样做)。

  • p_s = np.sqrt(...) 更改为p_s = sqrt(...)。这将删除仅对单个值进行操作的 numpy 调用。你好像在别处做过,所以我不知道你为什么错过了这条线。

  • 尽可能使用固定大小的 C 数组而不是 numpy 数组:

    cdef double beta[3]
    # ...
    beta[0] = X[1]/X[5]
    beta[1] = X[3]/X[5]
    beta[2] = p_s/X[5]
    

    当大小在编译时已知(并且相当小)并且您不想返回它时,您可以这样做。这避免了对np.zeros 的调用和一些后续的类型检查以将其分配给类型化的numpy 数组。我认为beta 是您唯一可以这样做的地方。

  • np.angle( complex(x/r, y/r) ) 可以替换为atan2(y/r, x/r)(使用atan2 from libc.math。你也可以通过r 失去除法

  • cdef int i 有助于使您的 for 循环在 getEMField 中更快(Cython 通常擅长自动提取循环变量的类型,但在这里似乎失败了)

  • 我怀疑逐个元素分配E 比分配整个数组更快:

            E[0] = k2*x/0.045 + k10*rn**9*cos(9*arg)
            E[1] = -k2*y/0.045 -k10*rn**9*sin(9*arg)
    
  • 指定listtuple 之类的类型没有多大价值,实际上它可能会使代码稍微慢一些(因为检查类型会浪费时间)。

  • 更大的变化是将EB 作为指针传递给GetEMField,而不是使用分配它们np.zeros。这将允许您将它们分配为equations (cdef double E[3]) 中的静态 C 数组。缺点是GetEMField 必须是cdef,因此不能再从 Python 调用(但如果你愿意,你也可以创建一个 Python 可调用的包装函数)。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2016-11-02
    • 1970-01-01
    • 2018-09-17
    • 2023-02-09
    相关资源
    最近更新 更多