【Question Title】: Linear Regression with Gradient Descent
【Posted】: 2023-05-21 03:47:01
【Question】:

I wrote the following Java program to implement linear regression with gradient descent. The code runs, but the results are not accurate: the predicted value of y is not close to the actual value of y. For example, when x = 75, the expected y is 208 but the output is y = 193.784.

class LinReg {

    double theta0, theta1;

    // Estimate theta0 and theta1 in closed form (ordinary least squares),
    // then hand the parameters to gradient descent.
    void buildModel(double[] x, double[] y) {
        double x_avg, y_avg, x_sum = 0.0, y_sum = 0.0;
        double xy_sum = 0.0, xx_sum = 0.0;
        int n = x.length, i;
        for( i = 0; i < n; i++ ) {
            x_sum += x[i];
            y_sum += y[i];
        }
        x_avg = x_sum/n;
        y_avg = y_sum/n;

        for( i = 0; i < n; i++) {
            xx_sum += (x[i] - x_avg) * (x[i] - x_avg);
            xy_sum += (x[i] - x_avg) * (y[i] - y_avg);
        }
        theta1 = xy_sum/xx_sum;
        theta0 = y_avg - (theta1 * x_avg);
        System.out.println(theta0);
        System.out.println(theta1);

        gradientDescent(x, y, 0.1, 1500);
    }

    // Batch gradient descent: update both parameters from the full data set on each iteration.
    void gradientDescent(double x[], double y[], double alpha, int maxIter) {
        double oldtheta0, oldtheta1;
        oldtheta0 = 0.0;
        oldtheta1 = 0.0;
        int n = x.length;
        for(int i = 0; i < maxIter; i++) {
            if(hasConverged(oldtheta0, theta0) && hasConverged(oldtheta1, theta1))
                break;
            oldtheta0 = theta0;
            oldtheta1 = theta1;
            theta0 = oldtheta0 - (alpha * (summ0(x, y, oldtheta0, oldtheta1)/(double)n));
            theta1 = oldtheta1 - (alpha * (summ1(x, y, oldtheta0, oldtheta1)/(double)n));
            System.out.println(theta0);
            System.out.println(theta1);


        }
    }

    // Sum of prediction errors over all points (gradient of the cost w.r.t. theta0, times n).
    double summ0(double x[], double y[], double theta0, double theta1) {
        double sum = 0.0;
        int n = x.length, i;
        for( i = 0; i < n; i++ ) {
            sum += (hypothesis(theta0, theta1, x[i]) - y[i]);
        }
        return sum;
    }

    // Sum of prediction errors weighted by x (gradient of the cost w.r.t. theta1, times n).
    double summ1(double x[], double y[], double theta0, double theta1) {
        double sum = 0.0;
        int n = x.length, i;
        for( i = 0; i < n; i++ ) {
            sum += (((hypothesis(theta0, theta1, x[i]) - y[i]))*x[i]);
        }
        return sum;
    }

    boolean hasConverged(double oldTheta, double newTheta) {
        return ((newTheta - oldTheta) < (double)0);
    }

    double predict(double x) {
        return hypothesis(theta0, theta1, x);
    }

    // h(x) = theta0 + theta1 * x
    double hypothesis(double thta0, double thta1, double x) {
        return (thta0 + thta1 * x);
    }
}

public class LinearRegression {
    public static void main(String[] args) {
        //Height data
        double x[] = {63.0, 64.0, 66.0, 69.0, 69.0, 71.0, 71.0, 72.0, 73.0, 75.0};
        //Weight data
        double y[] = {127.0, 121.0, 142.0, 157.0, 162.0, 156.0, 169.0, 165.0, 181.0, 208.0};
        LinReg model = new LinReg();
        model.buildModel(x, y);
        System.out.println("----------------------");
        System.out.println(model.theta0);
        System.out.println(model.theta1);
        System.out.println(model.predict(75.0));
    }
}

【Question Comments】:

    Tags: java machine-learning


    【Solution 1】:

    Nothing is wrong.

    I verified the solution in R:

    x <- c(63.0, 64.0, 66.0, 69.0, 69.0, 71.0, 71.0, 72.0, 73.0, 75.0)
    y <- c(127.0, 121.0, 142.0, 157.0, 162.0, 156.0, 169.0, 165.0, 181.0, 208.0)
    
    mod <- lm(y~x)
    summary(mod)
    
    Call:
    lm(formula = y ~ x)
    
    Residuals:
         Min       1Q   Median       3Q      Max 
    -13.2339  -4.0804  -0.0963   4.6445  14.2158 
    
    Coefficients:
                 Estimate Std. Error t value Pr(>|t|)    
    (Intercept) -266.5344    51.0320  -5.223    8e-04 ***
    x              6.1376     0.7353   8.347 3.21e-05 ***
    ---
    Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
    
    Residual standard error: 8.641 on 8 degrees of freedom
    Multiple R-squared:  0.897,   Adjusted R-squared:  0.8841 
    F-statistic: 69.67 on 1 and 8 DF,  p-value: 3.214e-05
    

    Computing y-hat for an X value of 75:

    -266.5344 +(6.1376 *75)
    

    [1] 193.784

    That is a correct prediction. I think the confusion is about how regression works. Regression does not give you back the exact observed value in the training data that corresponds to a given value of the independent variable; that would just be a lookup table, not a statistical model (and it would be unable to interpolate or extrapolate).

    Regression fits a least-squares line to your data to estimate the model equation, and then uses that equation to predict the dependent variable for any given value of the independent variable. The only situation in which it would reproduce the training points exactly is when the model is overfitting (which is bad).
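
    To make that concrete, here is a minimal Java sketch (the class name FitCheck and the variable names are mine, purely for illustration) that plugs the lm() coefficients above into theta0 + theta1 * x, prints the residual at every training point, and computes R-squared. The line tracks the data closely (R-squared comes out near the 0.897 reported by summary(mod)), yet no single training point is reproduced exactly:

    // Illustrative sketch: evaluate how well the fitted line describes the training data.
    public class FitCheck {
        public static void main(String[] args) {
            double[] x = {63.0, 64.0, 66.0, 69.0, 69.0, 71.0, 71.0, 72.0, 73.0, 75.0};
            double[] y = {127.0, 121.0, 142.0, 157.0, 162.0, 156.0, 169.0, 165.0, 181.0, 208.0};
            // Least-squares estimates from the R output above.
            double theta0 = -266.5344, theta1 = 6.1376;

            double yMean = 0.0;
            for (double v : y) yMean += v;
            yMean /= y.length;

            double ssRes = 0.0, ssTot = 0.0;
            for (int i = 0; i < x.length; i++) {
                double yHat = theta0 + theta1 * x[i];   // model prediction
                double residual = y[i] - yHat;          // observed minus predicted
                ssRes += residual * residual;
                ssTot += (y[i] - yMean) * (y[i] - yMean);
                System.out.printf("x=%.1f  y=%.1f  yHat=%.3f  residual=%.3f%n",
                        x[i], y[i], yHat, residual);
            }
            // R^2 near 0.897 means the line explains most of the variance in y,
            // even though no individual training point is hit exactly.
            System.out.printf("R^2 = %.4f%n", 1.0 - ssRes / ssTot);
        }
    }

    Each residual is the gap between an observed weight and the fitted line; minimizing the sum of their squares is exactly what both lm() and the posted gradient-descent code are doing.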

    More information and links:

    https://en.wikipedia.org/wiki/Regression_analysis

    【Discussion】:

    • Yes, I was confusing the concepts of regression and gradient descent. Thank you for the answer.
    • @noober Glad to help. If you want a good Java library for ML and regression, I recommend Weka.