【发布时间】:2021-12-29 14:05:47
【问题描述】:
我在 Python 中遇到了 Numba 的 njit 工具的问题。我注意到该函数在使用 @numba.njit 运行和作为纯 Python 代码运行时会给出不同的结果。特别是在调试之后,我注意到使用 numpy 执行矩阵求逆时会出现计算差异。请在下面查看我的测试代码。矩阵 A 和向量 b 的值位于以下 csv 文件中,可通过以下链接访问: A.csv 和 b.csv
普通 Python 函数的结果是正确的。请帮我解决这个问题!我是否需要在 numpy 矩阵求逆函数周围使用 Numba 包装函数来解决似乎是数字问题的问题?
亲切的regrads,我期待很快收到你们的消息:)
艾哈迈德
@numba.njit
def cal_Test_jit(A,b):
c = np.linalg.inv(A)@b
return c, np.linalg.inv(A)
def cal_Test(A,b):
c = np.linalg.inv(A)@b
return c, np.linalg.inv(A)
A = np.loadtxt(open("A.csv", "rb"), delimiter=",")
b = np.loadtxt(open("b.csv", "rb"), delimiter=",")
c_jit, Ai_jit = cal_Test_jit(A,b)
c, Ai = cal_Test(A,b)
err_c = abs(c-c_jit)
err_A = abs(Ai_jit-Ai)
# ploting the error in the parameters
plt.figure()
plt.plot(err_c)
# only ploting the error in first three columns of A
fig, ax = plt.subplots(1,3)
ax[0].plot(err_A[:,0])
ax[1].plot(err_A[:,1])
ax[2].plot(err_A[:,2])
【问题讨论】: