【问题标题】:Plotly: How to retrieve regression results using plotly express?Plotly:如何使用 plotly express 检索回归结果?
【发布时间】:2021-02-10 23:10:02
【问题描述】:

您可以使用 plotly express / px.scatter 轻松绘制回归线,并使用 px.get_trendline_results(fig).iloc[0]["px_fit_results"].params[1] 检索回归结果,例如 beta。但是如何检索系数的其他参数,例如 R-squared 或 p-vales?

剧情:

代码:

# imports
import plotly.express as px
import pandas as pd
import numpy as np

# data
np.random.seed(123)
numdays=20
X = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()
Y = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()
df = pd.DataFrame({'X': X, 'Y':Y})

# figure using px.scatter
fig = px.scatter(df, x="X", y="Y", trendline="ols", template = 'plotly_dark')

fig.show()

【问题讨论】:

    标签: python plotly regression


    【解决方案1】:

    答案:

    model = px.get_trendline_results(fig)
    results = model.iloc[0]["px_fit_results"]
    alpha = params[0]
    beta = .params[1]
    p_beta = .pvalues[1]
    r_squared = .rsquared
    

    详情:

    所有回归结果均可通过以下方式获得:

    px.get_trendline_results(fig)
    

    它在运行时会返回一个看起来有点神秘的熊猫数据框:

                                          px_fit_results
    0  <statsmodels.regression.linear_model.Regressio...
    

    px_fit_results 下的元素是 statsmodels.regression.linear_model.RegressionResultsWrapper 类型的对象,它是 statsmodels 的包装器。

    所以如果我们通过设置来简化一下:

    models = px.get_trendline_results(fig)
    

    还有:

    results =  model.iloc[0]["px_fit_results"]
    

    然后我们可以使用以下方法检查该对象中可用的内容:

    dir(results)
    

    并找到一个应该需要的所有回归细节,例如:

    'predict',
    'pvalues',
    'remove_data',
    'resid',
    'resid_pearson',
    'rsquared',
    'rsquared_adj',
    'save',
    'scale',
    'ssr',
    'summary',
    'summary2',
    't_test',
    't_test_pairwise',
    

    但请注意,所有这些可用结果的结构都可以不同。

    运行results.rsquared 将返回一个浮点数0.611901357827784,而运行results.pvalues 将返回一个数组array([9.95834884e-01, 4.59734574e-05])。这又将分别通过results.pvalues[0]results.pvalues[1] 成为常​​数和趋势线的子集。

    有了这些信息,您可以提取其中的一些并将它们作为注释包括在内,以进一步改进您的情节图:

    剧情:

    完整代码:

    import plotly.graph_objects as go
    import plotly.express as px
    import pandas as pd
    import numpy as np
    import datetime
    
    # data
    np.random.seed(123)
    numdays=20
    X = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()
    Y = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()
    
    df = pd.DataFrame({'X': X, 'Y':Y})
    
    # Figure using plotly express
    fig = px.scatter(df, x="X", y="Y", trendline="ols", template = 'plotly_dark')
    
    # retrieve model estimates
    model = px.get_trendline_results(fig)
    results = model.iloc[0]["px_fit_results"]
    alpha = results.params[0]
    beta = results.params[1]
    p_beta = results.pvalues[1]
    r_squared = results.rsquared
    
    line1 = 'y = ' + str(round(alpha, 4)) + ' + ' + str(round(beta, 4))+'x'
    line2 = 'p-value = ' + '{:.5f}'.format(p_beta)
    line3 = 'R^2 = ' + str(round(r_squared, 3))
    summary = line1 + '<br>' + line2 + '<br>' + line3
    
    
    fig.add_annotation(
            x=110,
            y=140,
            xref="x",
            yref="y",
            text=summary,
            showarrow=False,
            font=dict(
                family="Courier New, monospace",
                size=16,
                color="#ffffff"
                ),
            align="left",
            arrowhead=2,
            arrowsize=1,
            arrowwidth=2,
            arrowcolor="#636363",
            ax=20,
            ay=-30,
            borderwidth=2,
            borderpad=4,
            bgcolor="rgba(100,100,100, 0.6)",
            opacity=0.8
            )
    
    fig.show()
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-11-04
      • 1970-01-01
      • 2021-10-21
      • 2020-11-01
      • 2020-04-24
      • 1970-01-01
      • 2022-12-06
      • 2020-08-11
      相关资源
      最近更新 更多