【发布时间】:2020-09-29 15:21:43
【问题描述】:
所以我的想法是(从神经网络的人那里借来的)如果我有数据集 D,我可以通过首先计算误差对参数(a、b 和 c)的导数来拟合二次曲线),然后做一个小的更新来最小化错误。我的问题是,下面的代码实际上并不适合曲线。对于线性的东西,类似的方法有效,但由于某种原因,二次似乎失败了。你能看到我做错了什么吗(假设或只是实现错误)
编辑:问题不够具体:以下代码不能很好地处理数据中的偏差。出于某种原因,它以某种方式更新了 a 和 b 参数,而 c 被抛在了后面。这种方法类似于机器人技术(使用雅可比矩阵查找路径)和神经网络(根据误差查找参数),所以它不是不合理的算法,现在的问题是,为什么这个特定的实现不会产生我期望的结果。
在下面的 Python 代码中,我使用数学作为 m,而 MSE 是一个计算两个数组之间的均方误差的函数。除此之外,代码是自包含的
代码:
def quadraticRegression(data, dErr):
a = 1 #Starting values
b = 1
c = 1
a_momentum = 0 #Momentum to counter steady state error
b_momentum = 0
c_momentum = 0
estimate = [a*x**2 + b*x + c for x in range(len(data))] #Estimate curve
error = MSE(data, estimate) #Get errors 'n stuff
errorOld = 0
lr = 0.0000000001 #learning rate
while abs(error - errorOld) > dErr:
#Fit a (dE/da)
deda = sum([ 2*x**2 * (a*x**2 + b*x + c - data[x]) for x in range(len(data)) ])/len(data)
correction = deda*lr
a_momentum = (a_momentum)*0.99 + correction*0.1 #0.99 is to slow down momentum when correction speed changes
a = a - correction - a_momentum
#fit b (dE/db)
dedb = sum([ 2*x*(a*x**2 + b*x + c - data[x]) for x in range(len(data))])/len(data)
correction = dedb*lr
b_momentum = (b_momentum)*0.99 + correction*0.1
b = b - correction - b_momentum
#fit c (dE/dc)
dedc = sum([ 2*(a*x**2 + b*x + c - data[x]) for x in range(len(data))])/len(data)
correction = dedc*lr
c_momentum = (c_momentum)*0.99 + correction*0.1
c = c - correction - c_momentum
#Update model and find errors
estimate = [a*x**2 +b*x + c for x in range(len(data))]
errorOld = error
print(error)
error = MSE(data, estimate)
return a, b, c, error
【问题讨论】:
-
堆栈溢出是针对特定的代码问题。找出你想出的算法有什么问题不在本网站的范围之内。请使用here 描述的原则来调试您的程序,并更新您的问题以使其更具体。
-
编辑问题更具体。代码有问题。
-
当然代码有问题。 更具体,我的意思是您应该进行一些调试并缩小问题的根源。 你认为你的代码有什么问题?尝试创建一个minimal reproducible example。按每个代码块的功能分解它,并尝试通过在使用调试器运行代码时单步执行代码来找出问题所在。再说一遍:“这是我的代码,告诉我为什么它不起作用”——一般调试问题是outside Stack Overflow's scope
-
那是百万美元的问题。该算法不是什么新东西(最小二乘拟合),而且这种特殊情况与我认为的一样小。存在一些问题,因为它不能很好地拟合曲线,例如偏差项 c 不会从初始位置移动,即使数据明显有偏差(所有点都远高于 0)。
-
@Nyxeria 我认为你的算法是正确的,我在my answer 中绘制了动画,我觉得它看起来不错。
标签: python curve-fitting data-fitting mse quadratic-curve