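Posted: 2022-01-25 00:35:06
Question:
I just started learning numba and wrote this exercise to measure how long it takes to solve a simple matrix problem. My goal is to parallelize the program with the Python numba library. The program has a function create_matrix(row: int, col: int) that takes two inputs and builds a 2D matrix. I then create two matrices, compute the sum of each one's primary diagonal, and add the two sums together. The problem is that numba doesn't seem to understand NumPy 2D arrays. Any help would be greatly appreciated. Thanks.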
# imports
from numba import njit
import numpy as np
# create a 2D matrix of consecutive numbers, starting at 1
def create_matrix(row, col):
    arr = np.array([[j + (col * i) for j in range(1, col + 1)] for i in range(row)])
    return np.matrix(arr)
# calculate the sum of the primary diagonal of matrix 1
jitted_function = njit()(create_matrix)
m1 = jitted_function(1, 1)
print(f"Matrix 1 : {m1}")
print(f"Matrix 1 diagonal: {np.diagonal(m1)}")
print(f"Matrix 1 sum of primary diagonal is : {np.trace(m1)}")
mat1_sum = np.trace(m1, dtype='i')
# calculate the sum of the primary diagonal of matrix 2
m2 = create_matrix(4, 4)
print(f"Matrix 2 : {m2}")
print(f"Matrix 2 diagonal : {np.diagonal(m2)}")
print(f"Matrix 2 Sum of diagonal is : {np.trace(m2)}")
mat2_sum = np.trace(m2, dtype='i')
sum_of_two_diagonals = mat1_sum + mat2_sum
print(f"THE SUM IS : {sum_of_two_diagonals}")
The error is:
Traceback (most recent call last):
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\practise.py", line 21, in <module>
    m1 = jitted_function(1, 1)
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\venv\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\venv\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 >>> setitem(array(undefined, 1d, C), int64, array(int64, 1d, C))
There are 16 candidate implementations:
  - Of which 16 did not match due to:
  Overload of function 'setitem': File: <numerous>: Line N/A.
    With argument(s): '(array(undefined, 1d, C), int64, array(int64, 1d, C))':
   No match.
During: typing of setitem at E:\Users\SoniTech\PycharmProjects\computer_hardware\practise.py (10)
File "practise.py", line 10:
def create_matrix(row, col):
    <source elided>
    """
    arr = np.array([[j + (col * i) for j in range(1, col + 1)] for i in range(row)]) # create a matrix starting 1
    ^
Discussion: