【问题标题】:Numba: calling jit with explicit signature using arguments with default valuesNumba:使用具有默认值的参数调用具有显式签名的 jit
【发布时间】:2023-07-25 21:33:01
【问题描述】:

我正在使用 numba 在 numpy 数组上创建一些包含循环的函数。

一切都很好,花花公子,我可以使用jit,并且我学会了如何定义签名。

现在我尝试在带有可选参数的函数上使用 jit,例如:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float))'])
def fun(a, b=3):
    return a + b

这可行,但如果我使用optional(float64) 而不是optional(float),它就不行(与intint64 相同)。我花了 1 个小时试图弄清楚这个语法(实际上,我的一个朋友偶然发现了这个解决方案,因为他忘记在浮动之后写 64),但是,为了我的爱,我不明白为什么会这样所以。我在互联网上找不到任何东西,而且 numba 的有关该主题的文档充其量也很少(他们指定 optional 应该采用 numba 类型)。

有人知道这是如何工作的吗?我错过了什么?

【问题讨论】:

  • 您需要'float64(float64, optional(float))' 部分吗?我怀疑你应该删除它。
  • 我不知道,我测试过它似乎运行得更快,但我可能误用了timeit。我知道它可以在没有显式签名的情况下工作,我只是想了解 numba 的工作原理

标签: python types jit numba


【解决方案1】:

啊,但是异常信息应该给出提示:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3.0)

这意味着optional 在这里是错误的选择。事实上optional represents None or "that type"。但是您需要一个可选参数,而不是可能是 floatNone 的参数,例如:

>>> fun(10, None)  # doesn't fail because of the signature!
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

我怀疑它只是“碰巧”为 optional(float) 工作,因为从 numbas 的角度来看,float 只是一个“任意 Python 对象”,所以使用 optional(float) 你可以传递 任何东西 在那里(这显然包括不给出论点)。对于optional(float64),它只能是Nonefloat64。该类别不够广泛,无法不提供论据

如果你给出类型Omitted,它就可以工作:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
13.0

但是,Omitted 似乎实际上并未包含在文档中,并且它有一些“粗糙的边缘”。例如,它不能在 nopython 模式下使用该签名进行编译,即使它似乎没有签名也是可能的:

@njit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3):
    return a + b

TypingError: Failed at nopython (nopython frontend)
Invalid usage of + with parameters (float64, class(float64))

-----------

@njit(['float64(float64, float64)', 'float64(float64, Omitted(3.))'])
def fun(a, b=3):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3)

-----------

@njit
def fun(a, b=3):
    return a + b

>>> fun(10.)
13.0

【讨论】:

  • 谢谢!我看到了那个错误,并想到了这个 Omitted 的事情,但找不到任何东西,或者无论如何让它工作,现在我意识到我使用的是带有小写 o 的 omitted... 也非常感谢这些例子:)
  • 没问题。但我建议你不要打扰Omitted(和optional)。如果我正确解释文档,这些是更多的内部类型。如果您省略签名并使用cache=True(以避免多次编译),则效果最佳。
  • 使用签名不会提高性能吗?
  • 并非如此。签名更多的是“限制”可能的输入并预先编译函数。如果没有签名,它将在您调用函数时推断类型并在调用时编译它(如果需要)(即时 - jit)。因此,通过提供签名,您可以加快 first(并且仅是第一次)对特定输入函数的调用。
最近更新 更多