【问题标题】:Gradient descent in linear regression goes wrong线性回归中的梯度下降出错
【发布时间】:2015-03-09 12:53:28
【问题描述】:

我实际上想使用线性模型来拟合一组“罪”数据,但事实证明,损失函数在每次迭代期间都会变大。我下面的代码有问题吗? (梯度下降法)

这是我在 Matlab 中的代码

m=20;
rate = 0.1;
x = linspace(0,2*pi,20);
x = [ones(1,length(x));x]
y = sin(x);
w = rand(1,2);
for i=1:500
    h = w*x;
    loss = sum((h-y).^2)/m/2 
    total_loss = [total_loss loss];
    **gradient = (h-y)*x'./m ;**
    w = w - rate.*gradient;
end

这是我想要拟合的数据

【问题讨论】:

  • 请改进您的代码。事实上,它没有运行,因为 m、rate、total_loss 没有定义。也许你甚至可以让它变得更短,更优雅。只包括看到问题所必需的最低限度。
  • 对正弦数据进行线性回归(即斜率和截距)可能不会得到正确的结果。如果您想拟合该曲线,您可能需要添加高阶项(二阶、三阶等)。我现在可能会告诉你,你不会得到准确的结果。
  • @JohnGreen - 没有问题。实际上,当我运行您的代码时,它会收敛,但损失项相对较大……嗯,我得到 0.1157。如果要拟合正弦数据,则需要包含高阶项。我会写一篇文章。
  • @JohnGreen - 是的,您需要增加代表曲线的术语数量。只用一条直线表示正弦数据肯定会给你带来很大的错误。
  • @JohnGreen - 这真的很奇怪!好的别担心。我快到这里了。

标签: matlab machine-learning gradient-descent


【解决方案1】:

您的代码没有问题。以你目前的框架,如果你能以y = m*x + b的形式定义数据,那么这段代码就绰绰有余了。实际上,我通过一些测试对其进行了测试,在这些测试中我定义了线的方程并向其添加了一些高斯随机噪声(幅度 = 0.1,平均值 = 0,标准偏差 = 1)。

但是,我要向您提及的一个问题是,如果您查看您的正弦数据,您会在[0,2*pi] 之间定义一个域。如您所见,您有多个 x 值映射到相同的 y 值但大小不同。例如,在 x = pi/2 我们得到 1,但在 x = -3*pi/2 我们得到 -1。这种高可变性对线性回归来说不是好兆头,所以我的一个建议是限制你的域......就像[0, pi]。它可能不收敛的另一个原因是您选择的学习率太高。我会将它设置为像0.01 这样的低值。正如您在 cmets 中提到的,您已经想通了!

但是,如果您想使用线性回归拟合非线性数据,则必须包含高阶项以解释可变性。因此,尝试包括二阶和/或三阶术语。这可以简单地通过修改您的 x 矩阵来完成,如下所示:

x = [ones(1,length(x)); x; x.^2; x.^3];

如果你回想一下,假设函数可以表示为线性项的总和:

h(x) = theta0 + theta1*x1 + theta2*x2 + ... + thetan*xn

在我们的例子中,每个theta 项都会构建我们多项式的高阶项。 x2 将是 x^2x3 将是 x^3。因此,这里我们仍然可以使用梯度下降的定义进行线性回归。

我还将控制随机生成种子(通过rng),以便您可以产生与我得到的相同的结果:

clear all; 
close all;
rng(123123);
total_loss = [];
m = 20;
x = linspace(0,pi,m); %// Change
y = sin(x);
w = rand(1,4); %// Change
rate = 0.01; %// Change
x = [ones(1,length(x)); x; x.^2; x.^3]; %// Change - Second and third order terms
for i=1:500
    h = w*x;
    loss = sum((h-y).^2)/m/2;
    total_loss = [total_loss loss];
    % gradient is now in a different expression
    gradient = (h-y)*x'./m ; % sum all in each iteration, it's a batch gradient
    w = w - rate.*gradient;
end

如果我们尝试这个,我们会得到w(你的参数):

>> format long g;
>> w


w =

  Columns 1 through 3

         0.128369521905694         0.819533906064327       -0.0944622478526915

  Column 4

       -0.0596638117151464

此时我的最终损失是:

loss =

       0.00154350916582836

这意味着我们的直线方程是:

y = 0.12 + 0.819x - 0.094x^2 - 0.059x^3

如果我们用你的正弦数据绘制这条线的方程,这就是我们得到的:

xval = x(2,:);
plot(xval, y, xval, polyval(fliplr(w), xval))
legend('Original', 'Fitted');

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-06-20
    • 2019-10-09
    • 2016-10-22
    • 2017-01-02
    • 2021-06-24
    • 1970-01-01
    相关资源
    最近更新 更多