【问题标题】:Annotating seaborn regplot parameters to the plot将 seaborn regplot 参数注释到图上
【发布时间】:2021-05-19 15:53:54
【问题描述】:

我正在尝试使用r2prmse 值使用seaborn.regplot 制作散点图。但是下面的代码返回错误AttributeError: 'AxesSubplot' object has no attribute 'map_dataframe'

fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)


g = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
def annotate(data, **kws):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_surface'] )
    rmse = mean_squared_error(data['est_fmc'], data['1h_surface'], squared=False)
    print(slope, intercept, rvalue, pvalue, rmse)
    ax = plt.gca()
    ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
            transform=ax.transAxes)
g.map_dataframe(annotate)


g = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax = axes[1] )
def annotate(data, **kws):
    slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x = data['est_fmc'], y= data['1h_profile'] )
    rmse = mean_squared_error(data['est_fmc'], data['1h_profile'], squared=False)
    print(slope, intercept, rvalue, pvalue, rmse)
    ax = plt.gca()
    ax.text(.02, .9, 'r2={:.2f}, p={:.2g}, rmse = {:.2f}'.format(rvalue**2, pvalue, rmse),
            transform=ax.transAxes)
g.map_dataframe(annotate)

有办法解决吗?非常感谢任何帮助。

【问题讨论】:

    标签: python seaborn


    【解决方案1】:

    seaborn 的一个重要方面是difference between figure-level and axes-level functionssns.regplot 是轴级函数。它获取ax(表示子图)作为可选参数,并始终返回创建图的ax

    map_dataframe 用于处理图形级函数(创建子图网格)。它可以与relplot 等函数一起使用。请注意,图形级函数不接受 ax 作为参数,它们总是创建自己的新图形。

    在您的情况下,您可以使用ax 参数以及xy 的参数修改annotate 函数,以使其适用于两个子图。 (Python中一个重要的概念是"DRY - Don't Repeat Yourself"。)

    这是修改后的代码,从一些测试数据开始。 (进一步的改进是将调用regplot 放入annotate 函数中,将该函数重命名为“regplot_with_annotation”)。

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    import scipy
    from sklearn.metrics import mean_squared_error
    
    def annotate(ax, data, x, y):
        slope, intercept, rvalue, pvalue, stderr = scipy.stats.linregress(x=data[x], y=data[y])
        rmse = mean_squared_error(data[x], data[y], squared=False)
        ax.text(.02, .9, f'r2={rvalue ** 2:.2f}, p={pvalue:.2g}, rmse={rmse:.2f}', transform=ax.transAxes)
    
    est_fmc = np.random.uniform(0, 10, 100)
    oneh_surface = 2 * est_fmc + np.random.normal(0, 5, 100) + 10
    oneh_profile = 3 * est_fmc + np.random.normal(0, 3, 100) + 5
    
    new_df = pd.DataFrame({'est_fmc': est_fmc, '1h_surface': oneh_surface, '1h_profile': oneh_profile})
    
    fig, axes = plt.subplots(1, 2, figsize=(15, 5), sharey=True)
    
    ax = sns.regplot(x='est_fmc', y='1h_surface', data=new_df, ax=axes[0])
    annotate(ax, data=new_df, x='est_fmc', y='1h_surface')
    
    ax = sns.regplot(x='est_fmc', y='1h_profile', data=new_df, ax=axes[1])
    annotate(ax, data=new_df, x='est_fmc', y='1h_profile')
    
    plt.tight_layout()
    plt.show()
    

    【讨论】:

      猜你喜欢
      • 2018-07-27
      • 1970-01-01
      • 1970-01-01
      • 2015-08-01
      • 1970-01-01
      • 2020-03-18
      • 1970-01-01
      • 2016-01-05
      • 2017-01-02
      相关资源
      最近更新 更多