答案:
model = px.get_trendline_results(fig)
results = model.iloc[0]["px_fit_results"]
alpha = params[0]
beta = .params[1]
p_beta = .pvalues[1]
r_squared = .rsquared
详情:
所有回归结果均可通过以下方式获得:
px.get_trendline_results(fig)
它在运行时会返回一个看起来有点神秘的熊猫数据框:
px_fit_results
0 <statsmodels.regression.linear_model.Regressio...
px_fit_results 下的元素是 statsmodels.regression.linear_model.RegressionResultsWrapper 类型的对象,它是 statsmodels 的包装器。
所以如果我们通过设置来简化一下:
models = px.get_trendline_results(fig)
还有:
results = model.iloc[0]["px_fit_results"]
然后我们可以使用以下方法检查该对象中可用的内容:
dir(results)
并找到一个应该需要的所有回归细节,例如:
'predict',
'pvalues',
'remove_data',
'resid',
'resid_pearson',
'rsquared',
'rsquared_adj',
'save',
'scale',
'ssr',
'summary',
'summary2',
't_test',
't_test_pairwise',
但请注意,所有这些可用结果的结构都可以不同。
运行results.rsquared 将返回一个浮点数0.611901357827784,而运行results.pvalues 将返回一个数组array([9.95834884e-01, 4.59734574e-05])。这又将分别通过results.pvalues[0] 和results.pvalues[1] 成为常数和趋势线的子集。
有了这些信息,您可以提取其中的一些并将它们作为注释包括在内,以进一步改进您的情节图:
剧情:
完整代码:
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
import datetime
# data
np.random.seed(123)
numdays=20
X = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()
Y = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()
df = pd.DataFrame({'X': X, 'Y':Y})
# Figure using plotly express
fig = px.scatter(df, x="X", y="Y", trendline="ols", template = 'plotly_dark')
# retrieve model estimates
model = px.get_trendline_results(fig)
results = model.iloc[0]["px_fit_results"]
alpha = results.params[0]
beta = results.params[1]
p_beta = results.pvalues[1]
r_squared = results.rsquared
line1 = 'y = ' + str(round(alpha, 4)) + ' + ' + str(round(beta, 4))+'x'
line2 = 'p-value = ' + '{:.5f}'.format(p_beta)
line3 = 'R^2 = ' + str(round(r_squared, 3))
summary = line1 + '<br>' + line2 + '<br>' + line3
fig.add_annotation(
x=110,
y=140,
xref="x",
yref="y",
text=summary,
showarrow=False,
font=dict(
family="Courier New, monospace",
size=16,
color="#ffffff"
),
align="left",
arrowhead=2,
arrowsize=1,
arrowwidth=2,
arrowcolor="#636363",
ax=20,
ay=-30,
borderwidth=2,
borderpad=4,
bgcolor="rgba(100,100,100, 0.6)",
opacity=0.8
)
fig.show()