【发布时间】:2021-12-29 22:42:54
【问题描述】:
我正在使用 Numba 进行课程加速。当你想在类中使用 Numba 时,你必须定义/预分配你的类变量。在这方面我的问题是在 jitclass 之前声明一个二维数组。以下 MWE 直接显示了我的问题:
import numpy as np
from numba import int32, float32
from numba.experimental import jitclass # import the decorator
spec = [
('value', int32), # a simple scalar field
('array', float32[:]), # an array field
('foo_matrix',int32[:,:]),
]
@jitclass(spec)
class Bag(object):
def __init__(self, value):
self.value = value
self.array = np.zeros(value)
self.foo_matrix = np.zeros((value, value))
@property
def size(self):
return self.array.size
def increment(self, val):
for i in range(self.size):
self.array[i] = val
return self.array
my_class = Bag(3)
当我执行此代码时,我收到以下错误:
Traceback (most recent call last):
File "/home/acer/codici/tech/numba_prototype.py", line 38, in <module>
my_class = Bag(3)
File "/usr/lib/python3/dist-packages/numba/experimental/jitclass/base.py", line 122, in __call__
return cls._ctor(*bind.args[1:], **bind.kwargs)
File "/usr/lib/python3/dist-packages/numba/core/dispatcher.py", line 414, in _compile_for_args
error_rewrite(e, 'typing')
File "/usr/lib/python3/dist-packages/numba/core/dispatcher.py", line 357, in error_rewrite
raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Internal error at <numba.core.typeinfer.CallConstraint object at 0x7fc6d4945c40>.
Failed in nopython mode pipeline (step: nopython mode backend)
Can only insert float* at [4] in {i8*, i8*, i64, i64, float*, [1 x i64], [1 x i64]}: got double*
File "numba_prototype.py", line 19:
def __init__(self, value):
<source elided>
self.value = value
self.array = np.zeros(value)
^
During: lowering "(self).array = $14call_method.5" at /home/acer/codici/tech/numba_prototype.py (19)
During: resolving callee type: jitclass.Bag#7fc6d5a2afa0<value:int32,array:array(float32, 1d, A),foo_matrix:array(int32, 2d, A)>
During: typing of call at <string> (3)
Enable logging at debug level for details.
File "<string>", line 3:
<source missing, REPL/exec in use?>
这与矩阵foo_matrix的声明有关。
关于我遵循this的类型定义。
当然,如果我注释掉关于数组声明和填充的行,代码就可以正常工作。 我应该如何修改/声明关于 jitclass 对象的矩阵?
编辑:在类中,我已将 foo_matrix 的声明从 np.zeros([value, value]) 更改为 np.zeros((value, value)),因为使用列表而不是元组定义 numpy 数组可能是错误的来源numba 函数。但是,即使进行此修改,问题仍然存在。
【问题讨论】:
-
"在类中,我已将 foo_matrix 的声明从 np.zeros([value, value]) 更改为 np.zeros((value, value)" 那些不是声明. Python 并没有真正有 变量声明,除非你在谈论类型提示,所以你真的需要更准确地知道你在说什么关于。如果您遇到错误,发布错误消息
-
@juanpa.arrivillaga 我正在写我得到的错误,但关于声明,这是重点。我正在使用 Numba,因此您必须在使用之前声明变量的类型,就像 C 一样。
-
您将
foo_matrix声明为整数,但正如错误消息所述,您正在为其分配一个浮点数组。
标签: python-3.x performance type-conversion declaration numba