【问题标题】:matplotlib subplots with no spacing, restricted figure size and tight_layout()matplotlib 子图没有间距,限制图形大小和tight_layout()
【发布时间】:2020-11-25 11:10:09
【问题描述】:

我正在尝试制作一个包含两个共享 x 轴的子图并且它们之间没有空间的图。我遵循了 matplotlib 库中的 Create adjacent subplots 示例。然而,我的情节需要有一个固定的大小,这使得一切变得复杂。如果我只是按照示例添加固定大小的图形大小,则标签被切断。如果我使用tight_layout 包含标签,那么这些图是间隔的。如何解决这个问题?此外,标题应该更接近图例。 非常感谢任何帮助!

示例程序,注释掉tight_layout看看有什么区别。

import numpy as np                                                                                             
import matplotlib.pyplot as plt                                                                                
                                                                                                               
x_min = -2*np.pi                                                                                               
x_max = 2*np.pi                                                                                                
resolution = 101                                                                                               
x_vals = np.linspace(x_min, x_max, resolution)                                                                 
y_upper = np.cos(x_vals)                                                                                       
y_lower = -np.cos(x_vals)                                                                                      
data3 = np.sin(x_vals)                                                                                         
                                                                                                               
fig = plt.figure(figsize=(80/25.4, 80/25.4))  # figsize is needed for later usage of the plot                  
ax = fig.subplots(2, 1, sharex=True)                                                                           
fig.subplots_adjust(hspace=0)                                                                                  
                                                                                                               
ax[0].plot(x_vals, y_upper, label="data 1")                                                                    
ax[0].plot(x_vals, y_lower, label="data 2")                                                                    
                                                                                                               
ax[1].set_xlim([x_min,x_max])                                                                                  
ax[0].set_ylim([-1.6,1.6])                                                                                     
ax[1].set_ylim([-1.3,1.3])                                                                                     
                                                                                                               
ax[1].plot(x_vals, data3, ls='-', label="data 3", color='C2')                                                  
                                                                                                               
ax[1].set_xlabel("xaxis")                                                                                      
ax[0].set_ylabel("yaxis 1")                                                                                    
ax[1].set_ylabel("yaxis 2")                                                                                    
ax[0].legend(bbox_to_anchor=(0, 1.02, 1., 0.102), loc='lower left', ncol=2, mode="expand", borderaxespad=0)    
                                                                                                               
fig.suptitle("Title")                                                                                          
fig.tight_layout()  # comment this out to see the difference                                                   
# fig.savefig('figure.png')                                                                                    
plt.show()

【问题讨论】:

    标签: python matplotlib


    【解决方案1】:

    您需要使用GridSpec 而不是subplots_adjust(),这样tight_layout() 会知道您想要零空间并保持这种状态。

    其实你在使用fig.subplots()的时候已经创建了GridSpec,所以你只需要在gridspec_kw=中传递一些额外的参数

    x_min = -2*np.pi                                                                                               
    x_max = 2*np.pi                                                                                                
    resolution = 101                                                                                               
    x_vals = np.linspace(x_min, x_max, resolution)                                                                 
    y_upper = np.cos(x_vals)                                                                                       
    y_lower = -np.cos(x_vals)                                                                                      
    data3 = np.sin(x_vals)                                                                                         
                                                                                                                   
    fig = plt.figure(figsize=(80/25.4, 80/25.4))  # figsize is needed for later usage of the plot             
    #
    # This is the line that changes. Instruct the gridspec to have zero vertical pad
    #     
    ax = fig.subplots(2, 1, sharex=True, gridspec_kw=dict(hspace=0))                                                                           
                                                                                              
    ax[0].plot(x_vals, y_upper, label="data 1")                                                                    
    ax[0].plot(x_vals, y_lower, label="data 2")                                                                    
                                                                                                                   
    ax[1].set_xlim([x_min,x_max])                                                                                  
    ax[0].set_ylim([-1.6,1.6])                                                                                     
    ax[1].set_ylim([-1.3,1.3])                                                                                     
                                                                                                                   
    ax[1].plot(x_vals, data3, ls='-', label="data 3", color='C2')                                                  
                                                                                                                   
    ax[1].set_xlabel("xaxis")                                                                                      
    ax[0].set_ylabel("yaxis 1")                                                                                    
    ax[1].set_ylabel("yaxis 2")                                                                                    
    ax[0].legend(bbox_to_anchor=(0, 1.02, 1., 0.102), loc='lower left', ncol=2, mode="expand", borderaxespad=0)    
                                                                                                                   
    fig.suptitle("Title")                                                                                          
    fig.tight_layout()  # Now tight_layout does not add padding between axes
    # fig.savefig('figure.png')                                                                                    
    plt.show()
    

    【讨论】:

    • 这很好用,而且看起来很健壮。为了使标题更接近传说,我使用了fig.suptitle("Title", y=0.89)。不是最好的方法,但结果看起来不错。
    【解决方案2】:

    使用子图获得精确结果可能会令人沮丧 - 使用 gridspec (https://matplotlib.org/3.3.3/tutorials/intermediate/gridspec.html) 会提高您的精度。

    但是,考虑到您所在的位置,我认为您可以通过以下方式获得想要的东西:

    import matplotlib.pyplot as plt                                                                                
                                                                                                                   
    x_min = -2*np.pi                                                                                               
    x_max = 2*np.pi                                                                                                
    resolution = 101                                                                                               
    x_vals = np.linspace(x_min, x_max, resolution)                                                                 
    y_upper = np.cos(x_vals)                                                                                       
    y_lower = -np.cos(x_vals)                                                                                      
    data3 = np.sin(x_vals)                                                                                         
                                                                                                                   
    fig = plt.figure(figsize=(80/25.4, 80/25.4))  # figsize is needed for later usage of the plot                  
    ax = fig.subplots(3, 1, sharex=True)                                                                           
    fig.subplots_adjust(hspace=0)                                                                                  
    ax[0].text(0,0.5,"Title", ha='center')
    ax[0].axis("off")
    ax[1].plot(x_vals, y_upper, label="data 1")                                                                    
    ax[1].plot(x_vals, y_lower, label="data 2")                                                                    
                                                                                                                   
    ax[2].set_xlim([x_min,x_max])                                                                                  
    ax[1].set_ylim([-1.6,1.6])                                                                                     
    ax[2].set_ylim([-1.3,1.3])                                                                                     
                                                                                                                   
    ax[2].plot(x_vals, data3, ls='-', label="data 3", color='C2')                                                  
                                                                                                                   
    ax[2].set_xlabel("xaxis")                                                                                      
    ax[1].set_ylabel("yaxis 1")                                                                                    
    ax[2].set_ylabel("yaxis 2")                                                                                    
    ax[1].legend(bbox_to_anchor=(0, 1.02, 1., 0.102), loc='lower left', ncol=2, mode="expand", borderaxespad=0)    
    
    #fig.tight_layout()  # comment this out to see the difference                                                   
    # fig.savefig('figure.png')                                                                                    
    plt.show()
    

    【讨论】:

    • 使用 ax.text 而不是 fig.suptitle 确实可以将标题带到所需的位置。但是,当我运行您的代码时,轴标签会被切断。也许不同的matplotlib版本?我使用 3.3.3
    【解决方案3】:

    当然,gridspec 是正确的方法,如果您处于脚本编写的早期阶段,you should adapt this。但是,如果您想轻松解决,也可以移动fig.subplots_adjust()

    #...
    fig.suptitle("Title")                                                                                          
    fig.tight_layout()    
    fig.subplots_adjust(hspace=0)                                            
    # fig.savefig('figure.png')                                                                                    
    plt.show()
    

    已保存图片:

    【讨论】:

    • 非常感谢,我只是将其用作快速修复。但是,必须注意我装饰情节的顺序似乎不太理想,因此将来我将使用上面提供的 gridspec 解决方案并接受它作为答案。