【问题标题】:Recursive function in numba CUDAnumba CUDA 中的递归函数
【发布时间】:2019-09-02 22:00:20
【问题描述】:

numba 的文档指出:

numba 中的递归支持目前仅限于带有显式类型注释的函数的自递归。

我做了这个简单的设备功能:

@cu.jit(numba.i4(numba.i4), device=True)
def mutate(val: int) -> int:
    if(val < 1):
        return val
    else:
        return mutate(val-1)

这是一个相当简单的递归测试。现在从我的内核代码调用这个函数我得到Untyped global name 'mutate': cannot determine Numba type of <class 'numba.ir.UndefinedType'> 错误。 我还应该如何指定函数的类型?我该如何解决这个问题?

【问题讨论】:

    标签: python cuda numba


    【解决方案1】:

    numba 中的递归支持目前仅限于带有显式类型注释的函数的自递归。

    首先要说明的是,numba 文档包含您引用的文本。那来自Numba Enhancement Proposal 6。扩展 numba 工作方式的提议。不是语言/编译器的功能。

    也就是说,这个:

    import numba
    
    @numba.jit(numba.i4(numba.i4))
    def mutate(val: int) -> int:
        if(val < 1):
            return val
        else:
            return mutate(val-1)
    

    如您所料:

    In [12]: %run recursion.py
    
    In [13]: mutate??
    Signature:       mutate(val:int) -> int
    Call signature:  mutate(*args, **kwargs)
    Type:            CPUDispatcher
    String form:     CPUDispatcher(<function mutate at 0x7f77f787f840>)
    File:            ~/SO/recursion.py
    Source:         
    @numba.jit(numba.i4(numba.i4))
    def mutate(val: int) -> int:
        if(val < 1):
            return val
        else:
            return mutate(val-1)
    Class docstring:
    Implementation of user-facing dispatcher objects (i.e. created using
    the @jit decorator).
    This is an abstract base class. Subclasses should define the targetdescr
    class attribute.
    Init docstring: 
    Parameters
    ----------
    py_func: function object to be compiled
    locals: dict, optional
        Mapping of local variable names to Numba types.  Used to override
        the types deduced by the type inference engine.
    targetoptions: dict, optional
        Target-specific config options.
    impl_kind: str
        Select the compiler mode for `@jit` and `@generated_jit`
    pipeline_class: type numba.compiler.BasePipeline
        The compiler pipeline type.
    
    In [14]: print(mutate(10))
    0
    

    但是 Numba CUDA 编译器(我猜是在 nopython 模式下编译)不会编译等效代码,正如您所发现的那样。由此,再加上 Numba CUDA 文档中没有提到递归这一事实,我会得出结论,Numba CUDA 编译器不支持递归。

    【讨论】:

    • 感谢指正。但在我结束问题之前,您能否验证 numba CUDA jit 函数完全丢弃了 argType 参数?我阅读了一些代码,似乎该属性被忽略了
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2020-10-03
    • 2019-08-29
    • 2013-05-08
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多