【问题标题】:Optimizing while cycle with numba for error tolerance使用 numba 优化 while 循环以实现容错
【发布时间】:2020-04-22 17:25:09
【问题描述】:

我在使用 numba 进行优化时有疑问。我正在编写一个定点迭代来计算一个名为 gamma 的数组的值,它满足方程 f(gamma)=gamma。我正在尝试使用 python 包 Numba 优化此功能。好像如下。

@jit
def fixed_point(gamma_guess):
    for i in range(17):
        gamma_guess=f(gamma_guess)
    return gamma_guess

Numba 能够很好地优化这个功能,因为它知道它会执行多少次操作,17 次,而且运行速度很快。但是我需要控制我想要的伽玛的误差容限,我的意思是,一个伽玛和下一个通过定点迭代获得的伽玛的差异应该小于某个数字epsilon = 0.01,然后我尝试了

@jit
def fixed_point(gamma_guess):
    err=1000
    gamma_old=gamma_guess.copy()
    while(error>0.01):
        gamma_guess=f(gamma_guess)
        err=np.max(abs(gamma_guess-gamma_old))
        gamma_old=gamma_guess.copy()
    return gamma_guess

它也可以工作并计算所需的结果,但没有上一个实现那么快,它要慢得多。我认为这是因为 Numba 无法很好地优化 while 循环,因为我们不知道它何时会停止。有没有办法可以优化它并像上次实现一样快地运行?

编辑:

这是我正在使用的 f

from scipy import fftpack as sp
S=0.01
Amu=0.7
@jit 
def f(gammaa,z,zal,kappa):
    ka=sp.diff(kappa)
    gamma0=gammaa
    for i in range(N):
        suma=0
        for j in range(N):
            if (abs(j-i))%2 ==1:
                if((z[i]-z[j])==0):
                    suma+=(gamma0[j]/(z[i]-z[j]))   
        gamma0[i]=2.0*Amu*np.real(-(zal[i]/z[i])+zal[i]*(1.0/(2*np.pi*1j))*suma*2*h)+S*ka[i]
    return  gamma0

我总是使用np.ones(2048)*0.5 作为初始猜测,我传递给我的函数的其他参数是z=np.cos(alphas)+1j*(np.sin(alphas)+0.1)zal=-np.sin(alphas)+1j*np.cos(alphas)kappa=np.ones(2048)alphas=np.arange(0,2*np.pi,2*np.pi/2048)

【问题讨论】:

  • This 建议,浮点比较在numba 中要慢得多,所以也许您必须将浮点比较转换为整数 1
  • 我尝试修改我的函数只做整数比较,但似乎没有帮助
  • 如果你使用gamma_old=gamma_guess 而不是gamma_old=gamma_guess.copy() 会发生什么?我看不出抄袭的理由。今晚晚些时候我可能有时间自己尝试一些解决方案
  • 是的,抱歉,我养成了一个习惯,因为有时事情不起作用,它解决了一些问题,但这次删除它似乎没有任何改变。非常感谢您以后自己尝试。
  • 现在,我发布了我的答案;你为什么用np.max

标签: python python-3.x numpy numba


【解决方案1】:

我做了一个小测试脚本,看看能不能重现你的错误:

import numba as nb

from IPython import get_ipython
ipython = get_ipython()

@nb.jit(nopython=True)
def f(x):
    return (x+1)/x


def fixed_point_for(x):
    for _ in range(17):
        x = f(x)
    return x

@nb.jit(nopython=True)
def fixed_point_for_nb(x):
    for _ in range(17):
        x = f(x)
    return x

def fixed_point_while(x):
    error=1
    x_old = x
    while error>0.01:
        x = f(x)
        error = abs(x_old-x)
        x_old = x
    return x

@nb.jit(nopython=True)
def fixed_point_while_nb(x):
    error=1
    x_old = x
    while error>0.01:
        x = f(x)
        error = abs(x_old-x)
        x_old = x
    return x

print("for loop without numba:")
ipython.magic("%timeit fixed_point_for(10)")

print("for loop with numba:")
ipython.magic("%timeit fixed_point_for_nb(10)")

print("while loop without numba:")
ipython.magic("%timeit fixed_point_while(10)")

print("for loop with numba:")
ipython.magic("%timeit fixed_point_while_nb(10)")

由于我不知道您的f,我只是使用了我能想到的最简单的稳定功能。然后,我使用numba 和不使用numba 运行测试,两次都使用forwhile 循环。我机器上的结果是:

for loop without numba:
3.35 µs ± 8.72 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
for loop with numba:
282 ns ± 1.07 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
while loop without numba:
1.86 µs ± 7.09 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
for loop with numba:
214 ns ± 1.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

出现以下想法:

  • 不可能,您的函数不可优化,因为您的 for 循环很快(至少您是这么说的;您是否在没有 numba 的情况下进行了测试?)。
  • 您的函数可能需要更多的循环来收敛,正如您所想的那样
  • 我们使用不同的软件版本。我的版本是:
    • numba 0.49.0
    • numpy 1.18.3
    • python 3.8.2

【讨论】:

  • 我忘了说我的函数的输入是一个numpy数组,我的f输出一个包含我的不动点的数组。
  • 不,它所用的时间远少于我放在那里的 17 个,我确信它会收敛,因为我已经对其进行了测试,并且它给出了与 while 版本相同的结果,但要快得多。跨度>
  • 那我会说,如果你编辑你的问题并给出这样一个f和相应的gamma猜测的例子,那将是最好的
  • 我不得不说,我绝不是这方面的专家,只是想帮助和理解这个问题。你试过nopython=True 并得到一个错误顺便说一句?
  • 我将使用我正在使用的函数 f 编辑我的问题,感谢您的帮助
猜你喜欢
  • 2018-05-20
  • 1970-01-01
  • 2016-03-18
  • 2023-03-18
  • 1970-01-01
  • 1970-01-01
  • 2017-01-07
  • 1970-01-01
  • 2011-05-14
相关资源
最近更新 更多