我也在寻找解决方案。不幸的是,@Carbon 的建议不起作用,因为 numba.typeof 为函数 bar 返回的类型与函数 baz 的类型不同,即使 bar 和 baz 的签名相同.
例子:
import numba
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def bar(a):
return 2 * a
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def baz(a):
return 3 * a
@numba.jit(
numba.int32(numba.typeof(bar), numba.int32),
nopython=True,
nogil=True,
)
def foo(fn, a):
return fn(a)
foo(bar, 2) 返回 4
foo(baz, 2) 返回以下异常:
Traceback (most recent call last):
File "test_numba.py", line 33, in <module>
print(foo(baz, 2))
File "<snip>\Python38\lib\site-packages\numba\core\dispatcher.py", line 656, in _explain_matching_error
raise TypeError(msg)
TypeError: No matching definition for argument type(s) type(CPUDispatcher(<function baz at 0x000001DFA8C2D1F0>)), int64
我发现的唯一解决方法是完全省略 foo 的函数签名,让 numba 弄清楚。我不知道有什么负面后果(如果有的话)可能会让你的代码运行。
例子:
import numba
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def bar(a):
return 2 * a
@numba.jit(
numba.int32(numba.int32),
nopython=True,
nogil=True,
)
def baz(a):
return 3 * a
@numba.jit(
nopython=True,
nogil=True,
)
def foo(fn, a):
return fn(a)
foo(bar, 2) 返回 4
foo(baz, 2) 返回 6