【问题标题】:Optimizing root finding algorithm from scipy从 scipy 优化寻根算法
【发布时间】:2018-05-22 14:59:53
【问题描述】:

我在我的代码中使用来自scipy.optimizeroot 函数和方法“excitingmixing”,因为其他方法,如标准牛顿法,不会收敛到我正在寻找的根。

但是我想使用不支持scipy 包的numba 优化我的代码。我试图在文档中查找“令人兴奋的混合”算法来自己编程:

https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root.html

我没有发现任何有用的东西,除了该方法“使用调整的对角雅可比近似”这一不太有用的声明。

如果有人能告诉我一些关于算法的事情,或者对如何以其他方式优化 scipy 函数有想法,我会很高兴。

这里要求的是一个最小的代码示例:

import numpy as np
from scipy import optimize
from numba import jit

@jit(nopython = True)
def func(x):
    [a, b, c, d] = x

    da = a*(1-b)
    db = b*(1-c)
    dc = c
    dd = 1

    return [da, db, dc, dd]

@jit(nopython = True)
def getRoot(x0):
    solution = optimize.root(func, x0, method="excitingmixing")
    return(solution.x)

root = getRoot([0.1,0.1,0.2,0.4])
print(root)

【问题讨论】:

  • 考虑运行分析器以确定消耗大部分计算时间的是ExcitingMixing优化器的开销,还是对目标函数的评估。如果是后者,您可以将目标函数移植到numba 并使用scipy 提供的标准算法。
  • 我很确定这不是我的职责。该函数已经过优化,大多数求根算法更快,但不会收敛到我正在寻找的根。
  • 提供一个完整的例子。许多 scipy 函数可以采用低级回调函数而不是 Python 函数。这是 scipy.integrate.quad stackoverflow.com/a/50097776/4045774 的示例

标签: python optimization scipy numba


【解决方案1】:

你可以在scipy的源码中查看excitingmixing选项的实现:

https://github.com/scipy/scipy/blob/c948e96ebb3454f6a82e9d14021cc601d7ce7a85/scipy/optimize/nonlin.py#L1272

您可能不想在 numba 中重新实现整个寻根算法。我可以测试的更好的策略是使用 numba 来优化传递给 scipy 方法的函数。您仍然需要为 scipy 调用函数支付一些开销,但如果瓶颈是评估函数,您可能会看到性能提高,并且可以使用 numba jitted 版本更快地完成。我发现最好只试验 numba 并使用 timeit 方法进行测试。

【讨论】:

  • 感谢您的快速回答 - 我已经优化了从 scipy 调用的函数,瓶颈实际上是 scipy 提供的带有“excitingmixing”选项的根查找算法。
【解决方案2】:

我写了一个小包装器 Minpack,名为 NumbaMinpack,可以在 numba 编译函数中调用:https://github.com/Nicholaswogan/NumbaMinpack

如果 Newton 的方法让您失败,您应该尝试 lmdif 方法。

from NumbaMinpack import lmdif, hybrd, minpack_sig
from numba import njit, cfunc
import numpy as np

@cfunc(minpack_sig)
def myfunc(x, fvec, args):
    fvec[0] = x[0]**2 - args[0]
    fvec[1] = x[1]**2 - args[1]
    
funcptr = myfunc.address # pointer to myfunc

x_init = np.array([10.0,10.0]) # initial conditions
neqs = 2 # number of equations
args = np.array([30.0,8.0]) # data you want to pass to myfunc
@njit
def test():
    # solve with lmdif
    sol = lmdif(funcptr, x_init, neqs, args)
    # OR solve with hybrd
    sol = hybrd(funcptr, x_init, args) 
    return sol
test() # it works!

【讨论】:

    猜你喜欢
    • 2017-03-11
    • 1970-01-01
    • 1970-01-01
    • 2022-01-24
    • 2023-03-25
    • 2010-12-15
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多