【问题标题】:Why is numba library in python not recognizing numpy 2D array为什么python中的numba库无法识别numpy 2D数组
【发布时间】:2022-01-25 00:35:06
【问题描述】:

我刚开始学习 numba,并在此处进行了练习,以确定解决简单矩阵问题所需的时间。我的目标是使用 python numba 库实现该程序的并行执行该程序有一个函数 create_matrix(row: int, col: int) 接受两个输入并创建一个二维矩阵。然后我创建两个矩阵,找到它们的主对角线之和并计算它们的总数。 问题是 numba 似乎不理解 numpy 二维数组。任何帮助将不胜感激。谢谢

#imports
from numba import njit
import numpy as np

#create a 2D matrix by consecutive numbers - starting 1
def create_matrix(row, col):
    arr = np.array([[j + (col * i) for j in range(1, col + 1)] for i in range(row)])

    return np.matrix(arr)

# calculate the sum of primary diagonals of matrix1
jitted_function = njit()(create_matrix)
m1 = jitted_function(1, 1)
print(f"Matrix 1 : {m1}")
print(f"Matrix 1 diagonal: {np.diagonal(m1)}")
print(f"Matrix 1 sum of primary diagonal is : {np.trace(m1)}")
mat1_sum = np.trace(m1, dtype='i')


# calculate the sum of primary diagonals of matrix2
m2 = create_matrix(4, 4)
print(f"Matrix 2 : {m2}")
print(f"Matrix 2 diagonal : {np.diagonal(m2)}")
print(f"Matrix 2 Sum of diagonal is : {np.trace(m2)}")
mat2_sum = np.trace(m2, dtype='i')

sum_of_two_diagonals = mat1_sum + mat2_sum
print(f"THE SUM IS :  {sum_of_two_diagonals}")

错误是

Traceback (most recent call last):
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\practise.py", line 21, in <module>
    m1 = jitted_function(1, 1)
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\venv\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\venv\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(array(undefined, 1d, C), int64, array(int64, 1d, C))
 
There are 16 candidate implementations:
   - Of which 16 did not match due to:
   Overload of function 'setitem': File: <numerous>: Line N/A.
     With argument(s): '(array(undefined, 1d, C), int64, array(int64, 1d, C))':
    No match.

During: typing of setitem at E:\Users\SoniTech\PycharmProjects\computer_hardware\practise.py (10)

File "practise.py", line 10:
def create_matrix(row, col):
    <source elided>
    """
    arr = np.array([[j + (col * i) for j in range(1, col + 1)] for i in range(row)])  # create a matrix starting 1
    ^

【问题讨论】:

    标签: python arrays numpy numba


    【解决方案1】:

    njit 模式不支持您的某些功能,例如np.matrix()。您可以改写 create_matrix() 函数,以便 numba 可以委托给它自己的函数。

    #imports
    from numba import njit
    import numpy as np
    
    #create a 2D matrix by consecutive numbers - starting 1
    def create_matrix(row, col):
        arr = np.zeros((row, col))
        for i in range(row):
            for j in range(1, col + 1):
                arr[i,j-1] = j + (col * i)
        return arr
    
    # calculate the sum of primary diagonals of matrix1
    jitted_function = njit()(create_matrix)
    

    【讨论】:

      猜你喜欢
      • 2018-06-15
      • 2021-09-07
      • 1970-01-01
      • 2020-11-01
      • 2018-05-28
      • 1970-01-01
      • 2016-05-22
      • 2017-11-26
      • 2022-11-16
      相关资源
      最近更新 更多