我现在可以自己回答这个问题了。我完成了插值算法。作为ev-br的回答的建议,我只是为cupy重写了scipy.interpolate.CubicSpline。
scipy.interpolate.CubicSpline 包含很多函数,如果我们只需要一个插值函数,这将无济于事。
类CubicSpline有一个父类scipy.interpolate.PPoly,如果你只想要一个插值函数,它也包含了很多不必要的函数。清理干净后,我只使用了 _PPolyBase、solve_banded() 和 prepare_input() 类。
最难的部分是用 cython 编写的函数evaluate()。 Cython 不支持 cupy,所以我使用支持 cuda 的 numba 来加速循环的速度。
函数evaluate()的头部应该是这样的:
@cuda.jit('void(complex128[:,:,:], float64[:], float64[:], complex128[:,:])')
def evaluate(c, x, xp, out):
需要注意的重要一点是评估函数不是线程安全函数。
只有evaluate() 中的第一个循环是:
for ip in range(len(xp)):
xval = xp[IP]
......
可以使用cuda.grid(1)和cuda.gridsize(1)
另外,我将evaluate_poly1() 和find_interval_descending() 结合在evaluate() 中,以更好地适应numbe 的cuda 支持。
速度超级快,比原来的scipy函数快3到4倍左右。
代码可以在这里找到:https://github.com/GavinJiacheng/Interpolation_CUPY