【发布时间】:2017-09-08 21:30:07
【问题描述】:
我已经能够成功地使用 SVR 来预测具有一个数据条目的数据集上的值。但是,我的数据集每个“行”或“条目”或任何您想要调用的条目都有 47 个条目。我已经上传了我的数据集 csv,并在我的代码中注释掉了 get_data 函数中的其他 46 个条目。
所有 47 个数据条目都是相对的,并影响 x,即球员的薪水。我正在尝试仅使用已知该球员薪水之前该球员可用的统计数据来预测该球员的未来薪水。但是,正如我所提到的,很多统计数据都定义了薪水,目前我只能对 1 个统计数据条目进行预测。
import csv
import numpy as np
from sklearn.svm import SVR
import matplotlib.pyplot as plt
salary = []
stats = []
def get_data(filename):
with open(filename, 'r', encoding='utf8', errors='ignore') as csvfile:
csvFileReader = csv.reader(csvfile)
for row in csvFileReader:
# stats.append(float(row[4])) #
# stats.append(int(row[5])) #
salary.append(float(row[6]))
# stats.append(int(row[8])) #
# stats.append(int(row[9])) #
# stats.append(int(row[10])) #
stats.append(int(row[11])) #
# stats.append(int(row[12])) #
# stats.append(int(row[13])) #
# stats.append(float(row[14])) #
# stats.append(int(row[15])) #
# stats.append(int(row[16])) #
# stats.append(int(row[17])) #
# stats.append(int(row[18])) #
# stats.append(int(row[19])) #
# stats.append(int(row[20])) #
# stats.append(int(row[21])) #
# stats.append(int(row[22])) #
# stats.append(int(row[23])) #
# stats.append(int(row[24])) #
# stats.append(float(row[25])) #
# stats.append(int(row[26])) #
# stats.append(int(row[27])) #
# stats.append(int(row[28])) #
# stats.append(int(row[29])) #
# stats.append(int(row[30])) #
# stats.append(int(row[31])) #
# stats.append(int(row[32])) #
# stats.append(int(row[33])) #
# stats.append(int(row[34])) #
# stats.append(int(row[35])) #
# stats.append(float(row[36])) #
# stats.append(int(row[37])) #
# stats.append(int(row[38])) #
# stats.append(int(row[39])) #
# stats.append(int(row[40])) #
# stats.append(int(row[41])) #
# stats.append(int(row[42])) #
# stats.append(int(row[43])) #
# stats.append(int(row[44])) #
# stats.append(int(row[45])) #
# stats.append(int(row[46])) #
# stats.append(float(row[47])) #
# stats.append(int(row[48])) #
# stats.append(int(row[49])) #
# stats.append(int(row[50])) #
# stats.append(int(row[51])) #
# stats.append(int(row[52])) #
return
get_data('dataset.csv')
def predict_salary(stats, salary, x):
stats = np.reshape(stats,(len(salary), int(len(stats)/len(salary))))
svr_lin = SVR(kernel='linear', C=1e3, epsilon=0.2, cache_size=7000)
svr_rbf = SVR(kernel= 'rbf', C=1e3, gamma=0.1, cache_size=7000)
svr_poly = SVR(kernel='poly', C=1e3, degree=2, cache_size=7000)
svr_lin.fit(stats, salary)
svr_rbf.fit(stats, salary)
svr_poly.fit(stats, salary)
plt.scatter(stats, salary, color='black', label='Data')
plt.plot(stats, svr_lin.predict(stats), color='green', label='Linear model')
plt.plot(stats, svr_rbf.predict(stats), color='red', label='RBF model')
plt.plot(stats, svr_poly.predict(stats), color='blue', label='Polynomial model')
plt.xlabel('Stats')
plt.ylabel('Salary')
plt.title('Support Vector Regression')
plt.legend()
plt.show()
return svr_lin.predict(x)[0], svr_rbf.predict(x)[0], svr_poly.predict(x)[0]
projected_salary = predict_salary(stats, salary, 1)
print (projected_salary)
这里是 dataset.csv,我只包含了 10 行,但我拥有的数据多达 200 行:
N/A,N/A,player 1,team,3,26,1350000,508500,22,31,32,8,361,3,0.217,0,0,0,0,25,33,48,11,390,13,0.256,0,0,0,0,9,18,22,1,225,4,0.215,0,0,0,0,22,27,37,8,313,9,0.192,0,0,0,0,0
N/A,N/A,player 2,team,3,27,805000,508500,15,26,17,4,176,1,0.242,0,0,0,0,0,0,0,0,2,0,0,0,0,0,0,1,1,2,0,13,0,0.231,0,0,0,0,10,10,17,1,168,1,0.201,0,0,0,0,0
N/A,N/A,player 3,team,3,25,2625000,508500,25,17,69,3,460,58,0.26,0,0,0,0,15,28,56,4,454,57,0.226,0,0,0,0,39,48,72,6,611,56,0.25,0,0,0,0,2,1,9,0,22,13,0.368,2,0,0,0,0
N/A,N/A,player 4,team,3,26,3575000,508500,65,81,73,30,601,6,0.243,0,0,0,0,37,46,44,11,497,13,0.258,0,0,0,0,29,36,47,10,411,4,0.221,0,0,0,1,25,36,41,8,335,5,0.265,0,0,0,0,0
N/A,N/A,player 5,team,3,28,1950000,508500,23,34,45,7,324,4,0.255,0,0,0,0,35,45,56,2,509,8,0.28,1,0,0,0,32,29,68,4,492,12,0.281,0,0,0,0,5,14,15,0,144,1,0.25,0,0,0,0,0
N/A,N/A,player 6,team,2.5,30,700000,508500,3,0,7,0,141,0,0.174,0,0,0,0,28,49,38,11,355,0,0.234,0,0,0,0,18,28,22,9,275,0,0.207,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
N/A,N/A,player 7,team,2.5,26,2550000,508500,31,39,67,6,622,17,0.294,1,0,0,0,25,35,57,1,452,19,0.272,0,0,0,0,3,4,13,1,125,1,0.237,0,0,0,0,5,10,17,0,131,0,0.289,0,0,0,0,0
N/A,N/A,player 8,team,3,28,938000,508500,15,28,21,6,166,4,0.284,0,0,0,0,8,10,13,2,113,0,0.146,0,0,0,0,3,4,8,0,79,1,0.213,0,0,0,0,11,19,16,4,197,0,0.189,0,0,0,0,0
N/A,N/A,player 9,team,3,24,2300000,508500,40,49,52,5,466,21,0.277,0,0,0,0,36,43,59,4,552,16,0.227,0,0,0,0,27,26,34,6,332,8,0.261,0,0,0,0,5,5,5,0,61,2,0.291,0,0,0,0,0
N/A,N/A,player 10,team,3,27,3025000,508500,63,70,57,24,548,0,0.245,0,0,0,0,30,31,30,10,234,0,0.304,0,0,0,0,57,76,74,24,478,8,0.312,0,0,0,0,23,17,32,5,213,2,0.263,0,0,0,0,0
我花了几天时间才使用 47 个条目中的 1 个来完成这项工作,还有几个人试图弄清楚如何让它为每个玩家分析整个系列。我是python的初学者,没有统计背景,所以我现在完全迷路了!感谢任何帮助或指导,谢谢!
【问题讨论】:
-
顺便说一句,数据集的 200 行远非“巨大”。如今,庞大的数据集以 TB 级计算。
标签: python scikit-learn regression