【问题标题】:How do I pass in calculated values to a list sort using numba.jit in python?如何在 python 中使用 numba.jit 将计算值传递给列表排序?
【发布时间】:2020-12-26 21:10:46
【问题描述】:

我正在尝试使用 Python 中的 numba-jit 函数中的自定义键对列表进行排序。简单的自定义键可以工作,例如我知道我可以使用这样的东西按绝对值排序:

import numba

@numba.jit(nopython=True)
def myfunc():
    mylist = [-4, 6, 2, 0, -1]
    mylist.sort(key=lambda x: abs(x))
    return mylist  # [0, -1, 2, -4, 6]

但是,在以下更复杂的示例中,我收到了一个我不理解的错误。

import numba
import numpy as np


@numba.jit(nopython=True)
def dist_from_mean(val, mu):
    return abs(val - mu)

@numba.jit(nopython=True)
def func():
    l = [1,7,3,9,10,-4,-2,0]
    avg_val = np.array(l).mean()
    l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
    return l

它报告的错误如下:

Traceback (most recent call last):
  File "testitout.py", line 18, in <module>
    ret = func()
  File "/.../python3.6/site-packages/numba/core/dispatcher.py", line 415, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/.../python3.6/site-packages/numba/core/dispatcher.py", line 358, in error_rewrite
    reraise(type(e), e, None)
  File "/.../python3.6/site-packages/numba/core/utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'avg_val' in a function that will escape.

File "testitout.py", line 14:
def func():
    <source elided>
    l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
                                                ^

你知道这里发生了什么吗?

【问题讨论】:

    标签: python jit numba


    【解决方案1】:

    你知道这里发生了什么吗?

    通过使用参数nopython = True,您可以停用对象模式,因此 Numba 无法将所有值作为 Python 对象处理(请参阅:https://numba.pydata.org/numba-doc/latest/glossary.html#term-object-mode)。 (参考其实是我今天碰巧写的另一个帖子:How call a `@guvectorize` inside a `@guvectorize` in numba?

    @numba.jit(nopython=True)
    def func():
        l = [1,7,3,9,10,-4,-2,0]
        avg_val = np.array(l).mean()
        l.sort(key=lambda x: dist_from_mean(x, mu=avg_val))
        return l
    

    无论如何,lambda 对于 numba jit 函数来说“太”复杂了——至少当它作为参数传递时(比较 https://github.com/numba/numba/issues/4481)。激活nopython 模式后,您只能使用有限数量的库 - 完整列表可在此处找到:https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

    这就是它抛出以下错误的原因:

    numba.core.errors.TypingError:在 nopython 模式管道中失败(步骤: 将 make_function 转换为 JIT 函数)无法捕获 与函数中的变量“avg_val”关联的非常量值 那会逃跑的。

    此外,当您拥有nopython = True 时,您在另一个中引用了一个 jit 加速函数。这也可能是问题的根源。

    我强烈建议您查看以下教程:http://numba.pydata.org/numba-doc/latest/user/5minguide.html#will-numba-work-for-my-code; 它应该可以帮助您解决类似问题!


    进一步阅读和来源:

    【讨论】:

    • 那么您如何建议在 jit 函数中使用这样的函数对列表进行自定义排序?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2018-10-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-02-05
    • 2016-12-25
    相关资源
    最近更新 更多