【问题标题】:Using Scipy curve_fit with piecewise function将 Scipy curve_fit 与分段函数一起使用
【发布时间】:2017-01-13 19:15:52
【问题描述】:

我收到优化警告:

OptimizeWarning: Covariance of the parameters could not be estimated
                 category=OptimizeWarning)

当尝试使用scipy.optimize.curve_fit 将我的分段函数拟合到我的数据时。意味着没有拟合发生。我可以轻松地将抛物线拟合到我的数据中,并且我正在为curve_fit 提供我认为很好的初始参数。下面的完整代码示例。有谁知道为什么curve_fit 可能无法与np.piecewise 相处?还是我犯了另一个错误?

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt


def piecewise_linear(x, x0, y0, k1, k2):
    y = np.piecewise(x, [x < x0, x >= x0],
                     [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])
    return y

def parabola(x, a, b):
    y = a * x**2 + b
    return y

x = np.array([-3, -2, -1, 0, 1, 2, 3])
y = np.array([9.15, 5.68, 2.32, 0.00, 2.05, 5.29, 8.62])


popt_piecewise, pcov = curve_fit(piecewise_linear, x, y, p0=[0.1, 0.1, -5, 5])
popt_parabola, pcov = curve_fit(parabola, x, y, p0=[1, 1])

new_x = np.linspace(x.min(), x.max(), 61)


fig, ax = plt.subplots()

ax.plot(x, y, 'o', ls='')
ax.plot(new_x, piecewise_linear(new_x, *popt_piecewise))
ax.plot(new_x, parabola(new_x, *popt_parabola))

ax.set_xlim(-4, 4)
ax.set_ylim(-2, 16)

【问题讨论】:

    标签: python numpy scipy curve-fitting


    【解决方案1】:

    这是类型的问题,您必须更改以下行,以便x 以浮点数形式给出:

    x = np.array([-3, -2, -1, 0, 1, 2, 3]).astype(np.float)
    

    否则piecewise_linear 可能最终会转换类型。

    为了安全起见,您也可以在此处设置初始点:

    popt_piecewise, pcov = curve_fit(piecewise_linear, x, y, p0=[0.1, 0.1, -5., 5.])
    

    【讨论】:

    • 你是怎么得出这个结论的?
    • 我试图用给定的数据点评估piecewise_linear,但没有奏效,所以我得出结论,问题一定出在某个地方。我认为这与np.piecewise 的一些奇怪行为有关。
    • 我尝试了相同的方法,但完全错过了。很好!
    • 我建议 x = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=np.float) 立即告诉 NumPy 构造一个浮点数组,而不是构造一个整数数组然后转换类型。
    • @zaq 是的,这样更好。
    【解决方案2】:

    为了完整起见,我将指出拟合分段线性函数不需要np.piecewise:任何此类函数都可以由绝对值构造而成,每个弯曲使用np.abs(x-x0) 的倍数。以下产生了与数据的良好拟合:

    def pl(x, x0, a, b, c):
        y = a*np.abs(x-x0) + b*x + c
        return y
    
    popt_pl, pcov = curve_fit(pl, x, y, p0=[0, 0, 0, 0])
    
    print(pl(x, *popt_pl))
    

    输出接近原始 y 值:

    [ 8.90899998  5.828       2.74700002 -0.33399996  2.03499998  5.32
      8.60500002]
    

    【讨论】:

      猜你喜欢
      • 2021-08-11
      • 2016-03-08
      • 1970-01-01
      • 2020-05-07
      • 2021-02-16
      • 1970-01-01
      • 2016-09-16
      • 2017-04-11
      • 2020-11-09
      相关资源
      最近更新 更多