【问题标题】:Fit one-dimensional data with scikit-learn to predict line用 scikit-learn 拟合一维数据来预测线
【发布时间】:2017-08-13 15:50:00
【问题描述】:

我用 scikit-learn 编写了代码来为一维玩具数据构建 SVR 预测模型,然后用 matplotlib 绘制它。

蓝线是真实数据。具有线性内核的模型符合一条不错的线,但对于 2 级内核,预测不是我所期望的。我想要一个模型来预测蓝线的值略低于橙色线的预测值。我画了一条黑线来形象化我的想法。

  1. 为什么会这样?数据似乎是 2 次多项式的一个很好的候选者。黑色趋势线跟随真实数据,然后在右边很晚地减少,如果我只看这个,应该比绿线提供的拟合更好阴谋。不应该根据数据找到具有 2 次多项式的模型吗?它也会在靠近蓝线的 X = 0 处很好地弯曲,而不是在该处具有更高估计 y 值的曲率。

  2. 如何实现我想要的模型?

下面的代码是完整且独立的,运行它得到上面的图(减去画黑线)

# some data
y = [0, 3642, 6414, 9844, 13210, 16072, 18868, 22275, 25551, 28949, 31680, 34412, 37290, 39858, 42557, 
    45094, 47354, 49547, 51874, 54534, 55987, 55987, 58377, 60767, 63109, 65060, 66865, 68540, 70328, 
    72035, 73905, 75791, 77873, 79791, 81775, 83726]
X = range(0, len(y))
X_longer = range(0, len(y)*2)

# train models
from sklearn.svm import SVR
import numpy as np
clf_1 = SVR(kernel='poly', C=1e3, degree=1)
clf_2 = SVR(kernel='poly', C=1e3, degree=2)

clf_1.fit(np.array(X).reshape(-1, 1), y)
clf_2.fit(np.array(X).reshape(-1, 1), y)

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt

# plot real data
plt.plot(X, y, linewidth=8.0, label='true data')

predicted_1_y = []
predicted_2_y = []

# predict data points based on models
for i in X_longer:
    predicted_1_y.append(clf_1.predict(np.array([i]).reshape(-1, 1)))
    predicted_2_y.append(clf_2.predict(np.array([i]).reshape(-1, 1)))

# plot model predictions
plt.plot(X_longer, predicted_1_y, linewidth=6.0, ls=":", label='model, degree 1')
plt.plot(X_longer, predicted_2_y, linewidth=6.0, ls=":", label='model, degree 2')

plt.legend(loc='upper left')
plt.show()

【问题讨论】:

    标签: python matplotlib scikit-learn regression


    【解决方案1】:

    发生这种情况是因为线性和二次特征最终总是会向上或向下增长。您需要像平方根或对数这样的运算来获取所需的衰减特征。

    一种简单的方法是在拟合之前转换输入信号。例如,假设一个平方根趋势:

    X = np.array(X)[:,None]**2
    clf = SVR(kernel='linear').fit(X, y) 
    

    对于更一般的用例,如果你真的不知道你想要的趋势,或者不想假设这样的特定转换,你可以尝试像 Eureqa 这样的回归工具来计算最佳转换和数学模型可能。

    【讨论】:

    • 只是说它最终总会上升或下降并不能真正解释任何事情。他想要的趋势线可能是一个二次趋势,直到向右一定距离才会下降。
    • @BrenBarn 我想这取决于他最终想要什么。如果没有像这样的更严格的约束,很难控制模型如何推断或超出数据的范围。
    • @BrenBarn 写的也是我的假设。上面的代码没有运行,TypeError: list indices must be integers, not tuple.
    猜你喜欢
    • 2015-10-25
    • 2023-03-09
    • 1970-01-01
    • 2017-04-17
    • 2016-07-24
    • 2016-04-22
    • 2018-05-26
    • 2014-05-25
    • 2015-10-12
    相关资源
    最近更新 更多