【问题标题】:How to add a line of best fit to scatter plot如何添加一条最适合散点图的线
【发布时间】:2016-09-11 01:50:39
【问题描述】:

我目前正在使用 Pandas 和 matplotlib 来执行一些数据可视化,我想在我的散点图中添加一条最适合的线。

这是我的代码:

import matplotlib
import matplotlib.pyplot as plt
import pandas as panda
import numpy as np

def PCA_scatter(filename):

   matplotlib.style.use('ggplot')

   data = panda.read_csv(filename)
   data_reduced = data[['2005', '2015']]

   data_reduced.plot(kind='scatter', x='2005', y='2015')
   plt.show()

PCA_scatter('file.csv')

我该怎么做?

【问题讨论】:

标签: python numpy pandas matplotlib plot


【解决方案1】:

您可以使用np.polyfit()np.poly1d()。使用相同的x 值估计一次多项式,并添加到由.scatter() 绘图创建的ax 对象。举个例子:

import numpy as np

     2005   2015
0   18882  21979
1    1161   1044
2     482    558
3    2105   2471
4     427   1467
5    2688   2964
6    1806   1865
7     711    738
8     928   1096
9    1084   1309
10    854    901
11    827   1210
12   5034   6253

估计一次多项式:

z = np.polyfit(x=df.loc[:, 2005], y=df.loc[:, 2015], deg=1)
p = np.poly1d(z)
df['trendline'] = p(df.loc[:, 2005])

     2005   2015     trendline
0   18882  21979  21989.829486
1    1161   1044   1418.214712
2     482    558    629.990208
3    2105   2471   2514.067336
4     427   1467    566.142863
5    2688   2964   3190.849200
6    1806   1865   2166.969948
7     711    738    895.827339
8     928   1096   1147.734139
9    1084   1309   1328.828428
10    854    901   1061.830437
11    827   1210   1030.487195
12   5034   6253   5914.228708

和情节:

ax = df.plot.scatter(x=2005, y=2015)
df.set_index(2005, inplace=True)
df.trendline.sort_index(ascending=False).plot(ax=ax)
plt.gca().invert_xaxis()

获得:

还提供了线方程:

'y={0:.2f} x + {1:.2f}'.format(z[0],z[1])

y=1.16 x + 70.46

【讨论】:

  • trendline.plot(ax=ax) 行给了我一个无效的语法错误
  • z = np.polyfit(x=data_reduced[['2005']], y=data_reduced[['2015']], 1) 行给了我一个“位置参数跟随关键字参数”错误
  • 对不起,degree 需要在=1 之前添加deg,见更新。
  • TypeError: 对于行 z = np.polyfit(x=data_reduced[['2005']], y=data_reduced[['2015']], deg=1),x 的预期一维向量。这是我的数据或代码的问题吗?
  • 需要使用.loc[],所以单列变成pd.Series。使用[[]] 选择会保留一列作为DataFrame,因此会出现维度警告。更新,同样适用于下一行。不好意思,时间不早了……
【解决方案2】:

另一个选项(使用np.linalg.lstsq):

# generate some fake data
N = 50
x = np.random.randn(N, 1)
y = x*2.2 + np.random.randn(N, 1)*0.4 - 1.8
plt.axhline(0, color='r', zorder=-1)
plt.axvline(0, color='r', zorder=-1)
plt.scatter(x, y)

# fit least-squares with an intercept
w = np.linalg.lstsq(np.hstack((x, np.ones((N,1)))), y)[0]
xx = np.linspace(*plt.gca().get_xlim()).T

# plot best-fit line
plt.plot(xx, w[0]*xx + w[1], '-k')

【讨论】:

    【解决方案3】:

    您可以使用Seaborn 一口气完成所有工作和情节。

    import pandas as pd
    import seaborn as sns
    data_reduced= pd.read_csv('fake.txt',sep='\s+')
    sns.regplot(data_reduced['2005'],data_reduced['2015'])
    

    【讨论】:

    • 但是我想用matplotlib! :(
    • 这个解决方案多么简单,真是太棒了!非常感谢!
    • 如果您想在循环和创建多个图表时一次查看一个图表,您仍然需要 matplotlib 的 plt.show()
    【解决方案4】:

    这涵盖了plotly 方法

    #load the libraries
    
    import pandas as pd
    import numpy as np
    import plotly.express as px
    import plotly.graph_objects as go
    
    # create the data
    N = 50
    x = pd.Series(np.random.randn(N))
    y = x*2.2 - 1.8
    
    # plot the data as a scatter plot
    fig = px.scatter(x=x, y=y) 
    
    # fit a linear model 
    m, c = fit_line(x = x, 
                    y = y)
    
    # add the linear fit on top
    fig.add_trace(
        go.Scatter(
            x=x,
            y=m*x + c,
            mode="lines",
            line=go.scatter.Line(color="red"),
            showlegend=False)
    )
    # optionally you can show the slop and the intercept 
    mid_point = x.mean()
    
    fig.update_layout(
        showlegend=False,
        annotations=[
            go.layout.Annotation(
                x=mid_point,
                y=m*mid_point + c,
                xref="x",
                yref="y",
                text=str(round(m, 2))+'x+'+str(round(c, 2)) ,
            )
        ]
    )
    fig.show()
    

    fit_line 在哪里

    def fit_line(x, y):
        # given one dimensional x and y vectors - return x and y for fitting a line on top of the regression
        # inspired by the numpy manual - https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.lstsq.html 
        x = x.to_numpy() # convert into numpy arrays
        y = y.to_numpy() # convert into numpy arrays
    
        A = np.vstack([x, np.ones(len(x))]).T # sent the design matrix using the intercepts
        m, c = np.linalg.lstsq(A, y, rcond=None)[0]
    
        return m, c
    

    【讨论】:

      【解决方案5】:

      上面的最佳答案是使用 seaborn。 补充一点,如果你用循环创建许多图,你仍然可以使用 matplotlib

          import pandas as pd
          import seaborn as sns
          import matplotlib.pyplot as plt
      
          data_reduced= pd.read_csv('fake.txt',sep='\s+')
          for x in data_reduced.columns:
              sns.regplot(data_reduced[x],data_reduced['2015'])
              plt.show()
      

      plt.show() 将暂停执行,以便您一次查看一个图

      【讨论】:

        【解决方案6】:

        只是添加到(更新罗伯特卡尔霍恩的答案)。如果您不指定 x,y,您现在将在新版本的 pandas 上收到未来警告。

        FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
        

        所以,如下。

        import pandas as pd
        import seaborn as sns
        data_reduced= pd.read_csv('fake.txt',sep='\s+')
        sns.regplot(x=data_reduced['2005'],y=data_reduced['2015']) 
        

        【讨论】:

          猜你喜欢
          • 2020-02-26
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 2019-08-15
          • 1970-01-01
          • 1970-01-01
          相关资源
          最近更新 更多