【Title】: Problem in the linear regression implementation
【Posted】: 2020-05-24 17:57:19
【Question】:

I am new to machine learning, and I tried to implement vectorized linear regression from scratch using numpy. I tested the implementation with y = x, but my loss keeps increasing and I cannot understand why. It would be great if someone could point out why this is happening. Thanks in advance!

import numpy as np

class LinearRegressor(object):
    def __init__(self, num_features):
        self.num_features = num_features
        self.w = np.random.randn(num_features, 1).astype(np.float32)
        self.b = np.array(0.0).astype(np.float32)

    def forward(self, x):
        # Linear model: y_hat = x @ w + b
        return np.dot(x, self.w) + self.b

    @staticmethod
    def loss(y_pred, y_true):
        # Mean squared error, halved so the gradient carries no factor of 2
        return np.average(np.power(y_pred - y_true, 2)) / 2

    def calculate_gradients(self, x, y_pred, y_true):
        # Gradients of the halved MSE with respect to w and b
        self.dl_dw = np.dot(x.T, y_pred - y_true) / len(x)
        self.dl_db = np.mean(y_pred - y_true)

    def optimize(self, step_size):
        # Plain gradient-descent update
        self.w -= step_size * self.dl_dw
        self.b -= step_size * self.dl_db

    def train(self, x, y, step_size=1.0):
        y_pred = self.forward(x)
        l = self.loss(y_pred=y_pred, y_true=y)
        self.calculate_gradients(x=x, y_pred=y_pred, y_true=y)
        self.optimize(step_size=step_size)
        return l

    def evaluate(self, x, y):
        return self.loss(self.forward(x), y_true=y)

check_reg = LinearRegressor(num_features=1)
x = np.arange(1000).reshape(-1, 1)  # fit the identity function y = x
y = x
losses = []
for iteration in range(100):
    loss = check_reg.train(x=x, y=y, step_size=0.001)
    losses.append(loss)
    print("Iteration: {}".format(iteration))
    print(loss)

Output

Iteration: 0
612601.7859402705
Iteration: 1
67456013215.98818
Iteration: 2
7427849474110884.0
Iteration: 3
8.179099502901393e+20
Iteration: 4
9.006330707513148e+25
Iteration: 5
9.917228672922966e+30
Iteration: 6
1.0920254505132042e+36
Iteration: 7
1.2024725981084638e+41
Iteration: 8
1.324090295064888e+46
Iteration: 9
1.4580083421516024e+51
Iteration: 10
1.60547085025467e+56
Iteration: 11
1.7678478362285333e+61
Iteration: 12
1.946647415292399e+66
Iteration: 13
2.1435307416407376e+71
Iteration: 14
2.3603265498975516e+76
Iteration: 15
2.599049318486855e+81
Iteration: 16
nan
(the loss overflows to nan and remains nan through iteration 99)

【Discussion】:

    Tags: python numpy machine-learning linear-regression


    【Solution 1】:

    There is nothing wrong with your implementation. Your step size is too large for the optimization to converge: every update overshoots the minimum and lands at a point with an even larger error, so the iterates bounce back and forth across the optimum while the loss blows up. Reduce your step size:

    loss = check_reg.train(x=x,y=y, step_size=0.000001)
    

    You will get:

    Iteration: 0
    58305.102166924036
    Iteration: 1
    25952.192344178206
    Iteration: 2
    11551.585414406314
    Iteration: 3
    5141.729521746186
    Iteration: 4
    2288.6353484460747
    Iteration: 5
    1018.6952280352172
    Iteration: 6
    453.4320214875039
    Iteration: 7
    201.82728832044089
    Iteration: 8
    89.83519431606754
    Iteration: 9
    39.98665864625944
    Iteration: 10
    17.798416262435936
    Iteration: 11
    7.92229454258205
    Iteration: 12
    3.526272890501929
    Iteration: 13
    1.5696002444816197
    Iteration: 14
    0.6986516574778796
    Iteration: 15
    0.3109875219688626
    Iteration: 16
    0.13843156434074647
    Iteration: 17
    0.061616235257299326
    Iteration: 18
    0.027424318402401473
    Iteration: 19
    0.012205888201891543
    Iteration: 20
    0.005434012356344396
    Iteration: 21
    0.0024188644277583476
    Iteration: 22
    0.0010770380211645404
    Iteration: 23
    0.0004796730257022216
    Iteration: 24
    0.00021339295719587025
    Iteration: 25
    9.499628306355218e-05
    Iteration: 26
    4.244764386691682e-05
    Iteration: 27
    1.8965112443214162e-05
    Iteration: 28
    8.56069334821767e-06
    Iteration: 29
    3.848135476439999e-06
    Iteration: 30
    1.7367004907528985e-06
    Iteration: 31
    8.07976330965736e-07
    Iteration: 32
    4.0167090640020525e-07
    Iteration: 33
    2.253979336583221e-07
    Iteration: 34
    1.5365746125585947e-07
    Iteration: 35
    1.2480275459766612e-07
    Iteration: 36
    1.1147859663321005e-07
    Iteration: 37
    1.0288427880059631e-07
    Iteration: 38
    1.0036079530613815e-07
    Iteration: 39
    9.901975516098116e-08
    Iteration: 40
    9.901971962009025e-08
    (from here the loss plateaus, shrinking only in the last few digits each iteration)
    Iteration: 99
    9.901762276149956e-08
    

    Hope this helps!
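
    To see why 0.001 diverges while 0.000001 converges: the loss here is quadratic, and gradient descent on a quadratic is only stable when the step size is below 2 / λ_max, where λ_max is the largest eigenvalue of the loss's Hessian, (Xᵀ X) / n for the halved MSE. A minimal sketch of that check on the question's data (variable names are illustrative; the 2/λ_max bound is the standard stability condition, not something stated in the original answer):

    import numpy as np

    # Design matrix with a bias column, matching the model x @ w + b
    x = np.arange(1000).reshape(-1, 1)
    X = np.hstack([x, np.ones_like(x)])

    # Hessian of the halved MSE; gradient descent on a quadratic is
    # stable only if step_size < 2 / lambda_max
    hessian = X.T @ X / len(X)
    lam_max = np.linalg.eigvalsh(hessian).max()
    print("largest eigenvalue:", lam_max)          # roughly 3.3e5
    print("stability threshold:", 2.0 / lam_max)   # roughly 6e-6

    The question's step size of 0.001 is far above that threshold, which is why the loss explodes, while 1e-6 sits safely below it.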

    【Discussion】:

    • Thanks for the answer. How do you pick the step size, then? Is there a rule of thumb I can use?
    • There is a lot of literature on this and many techniques you can use; here is one overview: onmyphd.com/?p=gradient.descent
    • A very simple approach is to treat the step size as a hyperparameter, train one model for each value in a sensible range, and then use cross-validation to see which model learns best (see the sketch below).
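
    As a rough illustration of the sweep described in the last comment, here is a minimal sketch. It assumes the LinearRegressor class from the question (with the evaluate fix) is in scope; the train/validation split and the candidate step sizes are arbitrary illustrative choices, not from the thread:

    import numpy as np

    x = np.arange(1000, dtype=np.float64).reshape(-1, 1)
    y = x
    # Hold out the last 200 points for validation
    x_train, y_train = x[:800], y[:800]
    x_val, y_val = x[800:], y[800:]

    best = None
    for step_size in [1e-8, 1e-7, 1e-6, 3e-6]:
        model = LinearRegressor(num_features=1)
        for _ in range(100):
            model.train(x=x_train, y=y_train, step_size=step_size)
        val_loss = model.evaluate(x_val, y_val)
        print("step_size={:g}  val_loss={:g}".format(step_size, val_loss))
        # Ignore diverged runs (nan/inf losses)
        if np.isfinite(val_loss) and (best is None or val_loss < best[1]):
            best = (step_size, val_loss)
    print("best step size:", best[0])

    In practice you would shuffle the data and use proper k-fold cross-validation rather than a single ordered split, but the idea is the same: each candidate step size is scored on data the model was not trained on.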