【问题标题】:How to plot SciKit-Learn linear regression graph如何绘制 SciKit-Learn 线性回归图
【发布时间】:2020-11-01 15:01:36
【问题描述】:

我是 SciKit-Learn 的新手,我一直在研究 kaggle 上的回归问题(国王县 csv)。我一直在训练一个回归模型来预测房子的价格,我想绘制图表,但我不知道该怎么做。我正在使用python 3.6。任何意见或建议将不胜感激。

#importing numpy and pandas, seaborn

import numpy as np #linear algebra
import pandas as pd #datapreprocessing, CSV file I/O
import seaborn as sns #for plotting graphs
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt

data = pd.read_csv('kc_house_data.csv')
data = data.drop('date',axis=1)
data = data.drop('id',axis=1)

X = data
Y = X['price'].values
X = X.drop('price', axis = 1).values

X_train, X_test, Y_train, Y_test = train_test_split (X, Y, test_size = 0.30, random_state=21)


reg = LinearRegression()
kfold = KFold(n_splits=15, random_state=21)
cv_results = cross_val_score(reg, X_train, Y_train, cv=kfold, scoring='r2')

print(cv_results)

round(np.mean(cv_results)*100, 2)

【问题讨论】:

    标签: python plot scikit-learn linear-regression


    【解决方案1】:

    【讨论】:

      【解决方案2】:

      您可以使用matplotlib 进行绘图

      import matplotlib.pyplot as plt
      plt.figure(figsize=(16, 9))
      plt.plot(cv_results)
          
      plt.show()
      
      

      您可以使用多种类型的图,例如简单的线图或散点图。

      plt.barh(x, y) # for bar graph
      plt.plot(x,y)  # for line graph
      plt.scatter(x,y) # for scatter graph
      

      【讨论】:

        【解决方案3】:

        Seaborn 是一个非常有用的可视化库。如此之多,以至于您可以使用“seaborn.regplot”直接绘制数据和回归模型拟合线。它直接接受预测变量和响应变量,并吐出数据点和最佳拟合线的图。这是如何使用它的链接:

        https://seaborn.pydata.org/generated/seaborn.regplot.html

        【讨论】:

          【解决方案4】:

          我也在 kaggle 上做过同样的比赛。 对于回归,我会选择散点图:

          import matplotlib as plt
          plt.plot(x,y)
          

          至于该特定比赛的可视化,我将使用以下代码:

          # visualising some more outliers in the data values
          fig, axs = plt.subplots(ncols=2, nrows=0, figsize=(12, 120))
          plt.subplots_adjust(right=2)
          plt.subplots_adjust(top=2)
          sns.color_palette("husl", 8)
          for i, feature in enumerate(list(train[numeric]), 1):
          if(feature=='MiscVal'):
              break
          plt.subplot(len(list(numeric)), 3, i)
          sns.scatterplot(x=feature, y='SalePrice', hue='SalePrice', palette='Blues', data=train)
              
          plt.xlabel('{}'.format(feature), size=15,labelpad=12.5)
          plt.ylabel('SalePrice', size=15, labelpad=12.5)
          
          for j in range(2):
              plt.tick_params(axis='x', labelsize=12)
              plt.tick_params(axis='y', labelsize=12)
          
          plt.legend(loc='best', prop={'size': 10})
              
          plt.show()
          

          如果你想看看,我实际上已经在我的 GitHub 上上传了该竞赛的完整代码;)(我目前在该竞赛中排名前 14%)。

          【讨论】:

          • 谢谢兄弟,只是为了澄清一下,功能 == 'MiscVal' 是什么意思
          • 没问题。 “MiscVal”是 train.csv 文件的列之一。根据 kaggle 的描述,它基本上是分配给其他类别中未涵盖的特征的值(地块面积、街道、建筑质量、屋顶材料……)
          猜你喜欢
          • 2019-03-08
          • 2018-01-18
          • 2019-05-14
          • 2017-12-25
          • 2018-07-31
          • 2016-05-10
          • 2021-04-01
          • 2016-10-23
          • 2017-03-26
          相关资源
          最近更新 更多