【问题标题】:Numba njit compiler causes computes different numbers compared to plain Python code?与纯 Python 代码相比,Numba njit 编译器会导致计算不同的数字?
【发布时间】:2021-12-29 14:05:47
【问题描述】:

我在 Python 中遇到了 Numba 的 njit 工具的问题。我注意到该函数在使用 @numba.njit 运行和作为纯 Python 代码运行时会给出不同的结果。特别是在调试之后,我注意到使用 numpy 执行矩阵求逆时会出现计算差异。请在下面查看我的测试代码。矩阵 A 和向量 b 的值位于以下 csv 文件中,可通过以下链接访问: A.csvb.csv

普通 Python 函数的结果是正确的。请帮我解决这个问题!我是否需要在 numpy 矩阵求逆函数周围使用 Numba 包装函数来解决似乎是数字问题的问题?

亲切的regrads,我期待很快收到你们的消息:)

艾哈迈德

@numba.njit
def cal_Test_jit(A,b):
    c = np.linalg.inv(A)@b
    return c, np.linalg.inv(A)

def cal_Test(A,b):
    c = np.linalg.inv(A)@b
    return c, np.linalg.inv(A)

A = np.loadtxt(open("A.csv", "rb"), delimiter=",")
b = np.loadtxt(open("b.csv", "rb"), delimiter=",")

c_jit, Ai_jit = cal_Test_jit(A,b)
c, Ai = cal_Test(A,b)
err_c = abs(c-c_jit)
err_A = abs(Ai_jit-Ai)

# ploting the error in the parameters
plt.figure()
plt.plot(err_c)

# only ploting the error in first three columns of A
fig, ax = plt.subplots(1,3)
ax[0].plot(err_A[:,0])
ax[1].plot(err_A[:,1])
ax[2].plot(err_A[:,2])

【问题讨论】:

    标签: python numpy numba jit


    【解决方案1】:

    在您的问题中使用 numba 的一种方法是添加:

    @numba.jit(forceobj=True)
    

    这将得到真实的结果,但执行时间比较为using njit (different results) > this method (exact results) == plain Python e.g.使用 colab TPU:

    1000 loops, best of 5: 545 µs per loop     # using njit
    1000 loops, best of 5: 505 µs per loop     # this method
    1000 loops, best of 5: 500 µs per loop     # plain Python
    

    但正如之前在另一个 SO question 上推荐的那样,Numba 对于优化纯 Python 的子集非常有用,尤其是循环,接近优化 C 代码的性能 @BatWannaBe 和 几乎总是有比反转矩阵更好的方法,例如np.linalg.solve @Humer512;这在另一个 SO question @ali_m 上有很好的解释。

    【讨论】:

    • 可爱。非常感谢@Ali_Sh 的及时回复。我使用了 np.linalg.solve,它可以工作。这会让我继续前进。稍后,我将查看应用程序标志 forceobj=True 和您引用的附加帖子。似乎 np.linalg.solve 类似于 MATLAB 的 A\b。再次感谢 Ali_Sh :)
    猜你喜欢
    • 2014-02-23
    • 2015-06-15
    • 1970-01-01
    • 2012-11-04
    • 1970-01-01
    • 2021-03-10
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多