【发布时间】:2018-03-20 14:20:33
【问题描述】:
我一直在尝试使用 numba 进行一些自动生成/ jit 的功能。
您可以从 jit 函数中调用其他 jit 函数,因此如果您有一组特定的函数,可以很容易地在我想要的功能中进行硬编码:
from numba import jit
@jit(nopython=True)
def f1(x):
return (x - 2.0)**2
@jit(nopython=True)
def f2(x):
return (x - 5.0)**2
def hardcoded(x, c):
@jit(nopython=True)
def f(x):
return c[0] * f1(x) + c[1] * f2(x)
return f
lincomb = hardcoded(3, (0.5, 0.5))
print(lincomb(2))
Out: 4.5
但是,假设您事先不知道 f1、f2 是什么。我希望能够使用一个工厂来生成函数,然后让另一个来生成它们的线性组合:
def f_factory(x0):
@jit(nopython=True)
def f(x):
return (x - x0)**2
return f
def linear_comb(funcs, coeffs, nopython=True):
@jit(nopython=nopython)
def lc(x):
total = 0.0
for f, c in zip(funcs, coeffs):
total += c * f(x)
return total
return lc
并在运行时调用它。这可以在没有 nopython 模式的情况下工作:
funcs = (f_factory(2.0), f_factory(5.0))
lc = linear_comb(funcs, (0.5, 0.5), nopython=False)
print(lc(2))
Out: 4.5
但不是 nopython 模式。
lc = linear_comb(funcs, (0.5, 0.5), nopython=True)
print(lc(2))
TypingError: Failed at nopython (nopython frontend)
Untyped global name 'funcs': cannot determine Numba type of <class 'tuple'>
File "<ipython-input-100-2d3fb6214044>", line 11
所以看起来 numba 对 jit 函数的元组有问题。有没有办法让这种行为起作用?
函数和 c 的集合可能会变大,所以我真的很想让它在 nopython 模式下编译。
【问题讨论】:
-
有什么理由不将
x0作为f的参数并删除该因素? -
在实际代码中,f 可以更加任意。如果 x0 是函数的参数,那么线性组合需要知道每个 f 的参数,这可能是不同的。我想说,例如有 f(x, x0, alpha, beta) 和 f2(x, x0, bool_flag),并且有闭包封装,所以它们看起来是 f(x)
-
知道了,我想你可能会因为当前的 numba 限制而走运,没有做一些非常丑陋的代码生成。
-
我也很害怕。有没有办法为 jitted 函数生成名称,然后代码生成一个按名称或其他方式调用它们的函数?这不会是最糟糕的事情。