【问题标题】:How prediction and score works in scikit learnscikit learn 中的预测和评分如何工作
【发布时间】:2019-01-12 10:43:49
【问题描述】:

我正在尝试使用线性回归基于一组输入来预测输出,如下所示:

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

X = [[1, 1, 1, 1],
     [1, 1, 1, 1],
     [1, 2, 1, 1],
     [1, 3, 1, 1],
     [1, 4, 1, 1],
     [1, 2, 1, 1],
     [1, 3, 1, 1],
     [2, 4, 1, 1],
     [1, 1, 1, 1],
     [2, 1, 1, 1],
     [2, 4, 1, 1],
     [1, 5, 1, 1],
     [1, 1, 1, 1],
     [1, 1, 1, 1]]
y = [
    [1],
    [1],
    [1],
    [3],
    [2],
    [1],
    [3],
    [2],
    [1],
    [1],
    [2],
    [1],
    [1],
    [1],
   ]


# Split X and y into X_
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)

regression_model = LinearRegression()
regression_model.fit(X_train, y_train)

print(regression_model.score(X_test, y_test)) # -1.1817143658810325
print(regression_model.predict([[1, 1, 1, 1]]) # [[0.9694444444444441]]

我已将 X 值作为输入传递并期望 y 作为输出

它将分数显示为负值,预测输出为 [[0.9694444444444441]],我预计为 1。

我该如何解决这个问题?

【问题讨论】:

  • 你读过documentation吗?为什么您认为负分和预测输出不正确?

标签: python machine-learning scikit-learn linear-regression


【解决方案1】:

线性回归尝试使用最优超平面最小化均方误差。大多数数据都不是完全线性的(包括你的),所以预测不会是完美的。但是,考虑到线性约束,它们将具有尽可能低的误差。在您的示例中,0.97 和 1.00 之间没有太大区别。

在较少的维度中考虑以下线性回归,以使可视化更容易。回归所做的只是选择最适合数据的线。这并不意味着它贯穿每一点。当您使用那条线进行预测时,它会偏离一点点。

负分(直接来自文档)仅表示模型的性能比您只预测数据的平均值时更差。模型可以任意表现不佳。在您的情况下,由于线性回归能够学习这样一个常数模型,这表明对训练集过度拟合(可能是由于样本量小)。如果你对你的火车数据进行评分,你应该得到一个非否定的答案,而且可能是肯定的。

更仔细地检查您的模型,您会注意到,由于大类不平衡(您的 1 几乎是其他所有东西加起来的两倍),因此预测值为 1 的任何东西都相对接近。 2 有点糟糕,3 有一个可怕的预测。线性模型很难实现从 1 和 2 到 3 的巨大跳跃,因为只有几个点卡在其余点云的中间。

【讨论】:

  • 感谢@Hans Musgrave 的详细解释。你认为有没有其他模型可以适合我的数据,或者只是增加数据集的大小来解决这个问题。
  • 这最好作为另一个问题来处理,并且在一定程度上取决于您的目标。这里有两个重要的思想是interpolationregression。前者与您的数据完全匹配(因此将具有完美的训练准确性),而后者试图最小化某种错误(例如您尝试的线性回归)。虽然插值可以解决您的问题,但它通常缺乏泛化的能力,考虑到使用训练/测试拆分,这似乎是您关心的事情。大多数机器学习模型并不完美,所以我不会太担心。
  • 增加数据集的大小至少可以解决负分问题。对于这个问题,我预计具有足够数据的线性回归的得分约为 0.4,这样您就不会过度拟合训练集。
  • 感谢@Hans Musgrave 的建议。我尝试使用 GaussianNB,除了一两个得分为 0.5 的记录外,几乎所有记录的得分似乎都是 1.0。如果您更喜欢任何其他算法,请告诉我,我会试一试。
  • 数据这么少,也没有域信息,我真的没有最喜欢的算法。不过需要指出的一件事是,由于重复的 [1,2,1,1] 具有不同的 y 值,您将永远无法达到完美的准确度。此外,所有预测能力都存在于第二坐标中。您可以在不降低精度的情况下摆脱第 1、第 3 和第 4 个坐标。
猜你喜欢
  • 2016-08-07
  • 2015-02-19
  • 2021-01-11
  • 2015-01-12
  • 2018-05-14
  • 2016-04-03
  • 2015-03-01
  • 2014-01-17
相关资源
最近更新 更多