【问题标题】:Numba Signature for jitted function as argumentjitted 函数的 Numba 签名作为参数
【发布时间】:2021-02-22 20:15:10
【问题描述】:

我查看了 numba 文档,但找不到任何东西。

我有一个 jit 函数,它以 jited_function 作为参数。我想通过添加签名来进行即时编译,就像:

@jit(float64('jit_func.type', int32, int32...))

'jitted_func.type'应该是“函数类型”

当我这样做时:

type(jitted_func)

我得到一个 CPUDispatcher 对象

感谢您的帮助!

【问题讨论】:

    标签: python signature jit numba


    【解决方案1】:

    我也在寻找解决方案。不幸的是,@Carbon 的建议不起作用,因为 numba.typeof 为函数 bar 返回的类型与函数 baz 的类型不同,即使 barbaz 的签名相同.

    例子:

    import numba 
    
    @numba.jit(
        numba.int32(numba.int32),
        nopython=True,
        nogil=True,
    )
    def bar(a):
    
        return 2 * a
    
    @numba.jit(
        numba.int32(numba.int32),
        nopython=True,
        nogil=True,
    )
    def baz(a):
    
        return 3 * a
    
    @numba.jit(
        numba.int32(numba.typeof(bar), numba.int32),
        nopython=True,
        nogil=True,
    )
    def foo(fn, a):
    
        return fn(a)
    

    foo(bar, 2) 返回 4

    foo(baz, 2) 返回以下异常:

    Traceback (most recent call last):
      File "test_numba.py", line 33, in <module>
        print(foo(baz, 2))
      File "<snip>\Python38\lib\site-packages\numba\core\dispatcher.py", line 656, in _explain_matching_error
        raise TypeError(msg)
    TypeError: No matching definition for argument type(s) type(CPUDispatcher(<function baz at 0x000001DFA8C2D1F0>)), int64
    

    我发现的唯一解决方法是完全省略 foo 的函数签名,让 numba 弄清楚。我不知道有什么负面后果(如果有的话)可能会让你的代码运行。

    例子:

    import numba 
    
    @numba.jit(
        numba.int32(numba.int32),
        nopython=True,
        nogil=True,
    )
    def bar(a):
    
        return 2 * a
    
    @numba.jit(
        numba.int32(numba.int32),
        nopython=True,
        nogil=True,
    )
    def baz(a):
    
        return 3 * a
    
    @numba.jit(
        nopython=True,
        nogil=True,
    )
    def foo(fn, a):
    
        return fn(a)
    

    foo(bar, 2) 返回 4

    foo(baz, 2) 返回 6

    【讨论】:

      【解决方案2】:

      所以,我不确定如何从头生成您正在寻找的签名,但如果您有一个带有您想要的签名的编译函数示例,您可以使用numba.typeof(...) 来获取预期的签名,考虑,例如:

      import numba
      
      @numba.njit(numba.int32(numba.int32))
      def x(a):
          return a+1
      
      @numba.njit(numba.int32(numba.typeof(x), numba.int32))
      def y(fn,a):
          return fn(a)
          
      print(y(x,3))
      

      这是急切的编译,我检查了。如果你想进一步解决这个问题,正确的起点是numba.core.types.functions,而Dispatcher类型是在编译时专门处理的,参见numba.core.typing.context.BaseContext._resolve_user_function_type

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2017-08-18
        • 1970-01-01
        • 2021-09-19
        • 1970-01-01
        • 1970-01-01
        • 2022-06-17
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多