【问题标题】:Support Vector Machine Python 3.5.2支持向量机 Python 3.5.2
【发布时间】:2017-05-30 02:31:27
【问题描述】:

SVM 上搜索一些教程时,我在网上找到了 - Support Vector Machine _ Illustration - 下面的代码,但是它产生了一个 weird 图表。调试完代码,不知道是不是在Date列表上,准确的说:

dates.append(int(row[0].split('-')[0]))

从我这边来看是静态的(即 2016 年)或者如果还有其他情况,尽管我在代码中没有看到任何异常。

编辑

这个推论来自语法:

plt.scatter(dates, prices, color ='black', label ='Data'); 
plt.show()

事实上,屈服于垂直线,而

dates.append(int(row[0].split('-')[0]))

假设,如链接中所述,也反映在代码中,将每个日期 YYYY-MM-DD 转换为不同的整数值

编辑(2)

dates.append(md.datestr2num(row[0]))代替

函数get_data(filename)中的dates.append(int(row[0].split('-')[0]))确实有帮助!

import csv
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt

dates = []
prices = []

def get_data(filename):
    with open(filename, 'r') as csvfile:
        csvFileReader = csv.reader(csvfile)
        next(csvFileReader)
        for row in csvFileReader:
            dates.append(int(row[0].split('-')[0]))
            prices.append(float(row[6]))  # from 1 i.e from Opening to closing price

    return

def predict_prices(dates,prices,x):
    dates = np.reshape(dates,(len(dates),1))
    svr_lin = SVR(kernel = 'linear', C = 1e3)
    svr_poly = SVR(kernel = 'poly', C = 1e3, degree = 2)
    svr_rbf = SVR(kernel = 'rbf',  C = 1e3, gamma = 0.1)  

    svr_lin.fit(dates,prices)
    svr_poly.fit(dates,prices)
    svr_rbf.fit(dates,prices)

    plt.scatter(dates, prices, color ='black', label ='Data')
    plt.plot(dates, svr_rbf.predict(dates), color ='red', label = 'RBF model')
    plt.plot(dates, svr_rbf.predict(dates), color ='green', label = 'Linear model')
    plt.plot(dates, svr_rbf.predict(dates), color ='blue', label = 'Polynomial model')
    plt.xlabel('Date')
    plt.ylabel('Price')
    plt.title('Support Vector Regression')
    plt.legend
    plt.show()
    return svr_rbf.predict(x)[0], svr_lin.predict(x)[0], svr_poly.predict(x)[0]

get_data('C:/local/ACA.csv')
predict_prices(dates, prices, 29)

提前致谢

【问题讨论】:

  • 还有什么奇怪的?您正在尝试从...预测 2016 年的某个数字(价格),因此无能为力。这样的数据根本没有意义。
  • weird 是什么,您有 3 个不同的模型试图预测过去 30 天的价格,其中一个模型以一条垂直线结束。
  • @lejlot: plt.scatter(dates, prices, color ='black', label ='Data'); plt.show() 实际上产生了垂直线。正如帖子中提到的,这件事似乎来自dates.append(int(row[0].split('-')[0]))。实际上,每个日期YYYY-MM-DD 被转换为静态年份YYYY, 而不是每个不同日期的不同integer 值的事实可能是原因。我不知何故缺少的是为什么从语法.split('-') 中删除“-”似乎不起作用。
  • 将额外信息添加到您的问题中。在评论格式中很难阅读。
  • @hpaulj:我已经相应地编辑了帖子。如前所述,在我看来问题出在.split('-') 函数上。为了说明我的观点,csv 中的所有日期(例如 2016-12-28、...2016-12-30)都转换为 2016。

标签: python numpy matplotlib machine-learning svm


【解决方案1】:

get_data 创建两个列表,datesprices

np.array(dates)np.array(prices) 产生什么?形状和数据类型?由于您的绘图仅显示一个日期,我们需要查看该数组的值范围。

我编辑了您的问题,试图使函数定义正确。确保我做对了。

csv 中的日期列是什么样的?

看起来像您的 dates 解析:

In [25]: txt = '2016-02-20'

In [26]: txt.split('-')
Out[26]: ['2016', '02', '20']

In [27]: int(txt.split('-')[0])
Out[27]: 2016

所以你只是抓住了一年。这将解释

处的垂直散点图
In [29]: 0.010+2.01599e3
Out[29]: 2016.0

我认为这将是一个更好的日期转换 - 到 np.datetime64 dtype。

In [28]: np.array([txt], dtype='datetime64[D]')
Out[28]: array(['2016-02-20'], dtype='datetime64[D]')

【讨论】:

  • 感谢您确认一个想法。因此我会更新get_data(filename)
【解决方案2】:

我一直在使用来自多个示例(Siraj、Chaitjo、Jaihad 等)的 SVM 代码...并发现日期需要采用 DD-MM-YYYY 格式...所以使用的数据是日期...不是年份日期(如 dark.vapor 所述)。

而且数据只能保存 30 天……如这段代码中所见:

“predict_prices(日期,价格,29)”

否则,使用具有多个月份的数据文件(具有重复的天数...例如 1 月 15 日和 2 月 15 日)...我每天绘制多个价格,而不是每天仅绘制一天价格。

Edit2:我尝试改变数据集,发现数据行可以超过 29 行……只要日期只是一个整数序列。我长达 85 天(行)......他们都绘制了。所以我对上面预测代码中的“29”做了什么有点困惑?

如果能够使用包含多个月的较大数据文件......并选择我想要测试的日期范围......但现在这超出了我的编码技能。

我只是一个新手编码,所以我希望这是准确的,因为这似乎对我有用,使用 DD-MM-YYYY 格式可以正常工作并给我一个很好的干净情节。

希望这会有所帮助, 罗伯特

编辑:我刚刚找到了一篇描述这段代码的好文章……它确认了“日”解析为 DD-MM-YYYY 格式……

https://github.com/mKausthub/stock-er

dates.append(int(row[0].split('-')[0])) "获取 月份中的某天,因为日期的格式为 [date]-[month]-[year],因此索引为零。"

【讨论】: