【问题标题】:Plot in subplot matplotlib with custom (external) function使用自定义(外部)函数在 subplot matplotlib 中绘图
【发布时间】:2019-12-12 13:23:31
【问题描述】:

我有一个与之前已回答的问题类似的问题 (matplotlib: make plots in functions and then add each to a single subplot figure)。但是,我想要更高级的情节。我正在使用这个函数进行绘图(取自https://towardsdatascience.com/an-introduction-to-bayesian-inference-in-pystan-c27078e58d53):

def plot_trace(param, param_name='parameter', ax=None, **kwargs):
  """Plot the trace and posterior of a parameter."""

  # Summary statistics
    mean = np.mean(param)
    median = np.median(param)
    cred_min, cred_max = np.percentile(param, 2.5), np.percentile(param, 97.5)

  # Plotting
    #ax = ax or plt.gca()
    plt.subplot(2,1,1)
    plt.plot(param)
    plt.xlabel('samples')
    plt.ylabel(param_name)
    plt.axhline(mean, color='r', lw=2, linestyle='--')
    plt.axhline(median, color='c', lw=2, linestyle='--')
    plt.axhline(cred_min, linestyle=':', color='k', alpha=0.2)
    plt.axhline(cred_max, linestyle=':', color='k', alpha=0.2)
    plt.title('Trace and Posterior Distribution for {}'.format(param_name))

    plt.subplot(2,1,2)
    plt.hist(param, 30, density=True); sns.kdeplot(param, shade=True)
    plt.xlabel(param_name)
    plt.ylabel('density')
    plt.axvline(mean, color='r', lw=2, linestyle='--',label='mean')
    plt.axvline(median, color='c', lw=2, linestyle='--',label='median')
    plt.axvline(cred_min, linestyle=':', color='k', alpha=0.2, label='95% CI')
    plt.axvline(cred_max, linestyle=':', color='k', alpha=0.2)

    plt.gcf().tight_layout()
    plt.legend()

我希望这两个子图用于不同的参数。如果我使用此代码,它只会覆盖绘图并忽略 ax 参数。请您帮助我如何使其不仅适用于 2 个参数吗?

params = [mu_a,mu_tau]
param_names = ['A mean','tau mean']

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16,16))
fig.subplots_adjust(hspace=0.5)
fig.suptitle('Convergence and distribution of parameters')

plot_trace(params[0], param_name=param_names[0], ax = ax1)
plot_trace(params[1], param_name=param_names[1], ax = ax2)

预期的结果是让这些地块彼此相邻 How one subplot should look like (what the function makes)。谢谢。

【问题讨论】:

  • 您必须使用ax1.plot / ax2.plot 进行绘图,例如:matplotlib.org/3.1.1/api/_as_gen/…
  • 是的,您的函数需要采用 两个 轴并专门针对这些轴进行绘图。

标签: python matplotlib subplot


【解决方案1】:

正如@Sebastian-R 和@ImportanceOfBeingErnest 所建议的那样,问题在于您在函数内部创建子图,而不是使用传递给ax= 关键字的内容。另外,如果你想在两个不同的子图上绘图,你需要将两个轴实例传递给你的函数

有两种方法可以解决问题:

  • 第一个解决方案是重写您的函数以使用 matplotlib 的 the object-oriented API(这是我推荐的解决方案,虽然它需要更多的工作)

代码:

def plot_trace(param, axs, param_name='parameter', **kwargs):
    """Plot the trace and posterior of a parameter."""
    # Summary statistics
    (...)
    # Plotting
    ax1 = axs[0]
    ax1.plot(param)
    ax1.set_xlabel('samples')  # notice the difference in notation here
    (...)

    ax2 = axs[1]
    ax2.hist(...); sns.kdeplot(..., ax=ax2)
    (...)
  • 第二种解决方案保留您当前的代码,但由于 plt.xxx() 函数适用于当前坐标区,您需要将作为参数传递的坐标区设为当前坐标区:

代码:

def plot_trace(param, axs, param_name='parameter', **kwargs):
    """Plot the trace and posterior of a parameter."""
    # Summary statistics
    (...)
    # Plotting
    plt.sca(axs[0])
    plt.plot(param)
    plt.xlabel('samples')
    (...)

    plt.sca(axs[1])
    plt.hist(...)
    (...)

然后您需要在最后创建所需的子图数量(因此,如果您想绘制两个参数,则必须创建 4 个子图)并按照您的需要调用您的函数:

params = [mu_a,mu_tau]
param_names = ['A mean','tau mean']

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(16,16))
fig.subplots_adjust(hspace=0.5)
fig.suptitle('Convergence and distribution of parameters')

plot_trace(params[0], axs=[ax1, ax3], param_name=param_names[0])
plot_trace(params[1], axs=[ax2, ax4], param_name=param_names[1])

【讨论】:

  • 你可能想重读这个问题,以及我在它下面的评论。
  • 我明白了,我错过了函数生成两个子图的部分
  • 非常感谢 cmets 和 @Diziet Asahi 的解决方案,两者都有效。
猜你喜欢
  • 2021-01-20
  • 2019-06-24
  • 2022-01-20
  • 1970-01-01
  • 1970-01-01
  • 2013-01-08
  • 1970-01-01
  • 1970-01-01
  • 2020-12-08
相关资源
最近更新 更多