【问题标题】:How to remove the space between subplots in matplotlib.pyplot?如何删除 matplotlib.pyplot 中子图之间的空间?
【发布时间】:2017-04-25 14:42:18
【问题描述】:

我正在做一个项目,我需要将一个 10 行和 3 列的绘图网格放在一起。虽然我已经能够制作情节并安排子情节,但我无法制作一个没有空白的漂亮情节,例如下面来自gridspec documentatation. 的这个。

我尝试了以下帖子,但仍然无法完全删除示例图像中的空白。有人可以给我一些指导吗?谢谢!

这是我的图片:

下面是我的代码。 The full script is here on GitHub。 注意:images_2 和 images_fool 都是形状为 (1032, 10) 的扁平图像的 numpy 数组,而 delta 是形状为 (28, 28) 的图像数组。

def plot_im(array=None, ind=0):
    """A function to plot the image given a images matrix, type of the matrix: \
    either original or fool, and the order of images in the matrix"""
    img_reshaped = array[ind, :].reshape((28, 28))
    imgplot = plt.imshow(img_reshaped)

# Output as a grid of 10 rows and 3 cols with first column being original, second being
# delta and third column being adversaril
nrow = 10
ncol = 3
n = 0

from matplotlib import gridspec
fig = plt.figure(figsize=(30, 30)) 
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1]) 

for row in range(nrow):
    for col in range(ncol):
        plt.subplot(gs[n])
        if col == 0:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_2, ind=row)
        elif col == 1:
            #plt.subplot(nrow, ncol, n)
            plt.imshow(w_delta)
        else:
            #plt.subplot(nrow, ncol, n)
            plot_im(array=images_fool, ind=row)
        n += 1

plt.tight_layout()
#plt.show()
plt.savefig('grid_figure.pdf')

【问题讨论】:

    标签: python numpy matplotlib


    【解决方案1】:

    尝试将这一行添加到您的代码中:

    fig.subplots_adjust(wspace=0, hspace=0)
    

    对于每一个轴对象集:

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    

    【讨论】:

    • 此解决方案适用于纵横比设置为auto 的子图。对于使用imshow 绘制的图像,如此处的用例所示,它将失败。查看我的解决方案。
    • 感谢您的回复。这确实删除了垂直空间,但水平空间仍然存在......
    【解决方案2】:

    开头的注释:如果您想完全控制间距,请避免使用plt.tight_layout(),因为它会尝试将图中的图排列成均匀且均匀分布。这大部分都很好并且产生了令人愉快的结果,但可以随意调整间距。

    您从 Matplotlib 示例库中引用的 GridSpec 示例运行良好的原因是子图的方面没有预定义。也就是说,子图将简单地在网格上展开,并让设置的间距(在本例中为 wspace=0.0, hspace=0.0)与图形大小无关。

    与此相反,您使用imshow 绘制图像,并且图像的纵横比默认设置为相等(相当于ax.set_aspect("equal"))。也就是说,您当然可以将set_aspect("auto") 添加到每个绘图(并另外添加wspace=0.0, hspace=0.0 作为 GridSpec 的参数,如画廊示例中所示),这将生成一个没有间距的绘图。

    但是,在使用图像时,保持相同的纵横比是很有意义的,这样每个像素都一样宽,并且正方形阵列显示为正方形图像。
    然后你需要做的是使用图像大小和图形边距来获得预期的结果。 figure 的 figsize 参数是以英寸为单位的图形(宽度,高度),这里可以使用两个数字的比率。并且可以手动调整子图参数wspace, hspace, top, bottom, left 以获得所需的结果。 下面是一个例子:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import gridspec
    
    nrow = 10
    ncol = 3
    
    fig = plt.figure(figsize=(4, 10)) 
    
    gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1],
             wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845) 
    
    for i in range(10):
        for j in range(3):
            im = np.random.rand(28,28)
            ax= plt.subplot(gs[i,j])
            ax.imshow(im)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
    
    #plt.tight_layout() # do not use this!!
    plt.show()
    

    编辑:
    不必手动调整参数当然是可取的。因此可以根据行数和列数计算出一些最佳值。

    nrow = 7
    ncol = 7
    
    fig = plt.figure(figsize=(ncol+1, nrow+1)) 
    
    gs = gridspec.GridSpec(nrow, ncol,
             wspace=0.0, hspace=0.0, 
             top=1.-0.5/(nrow+1), bottom=0.5/(nrow+1), 
             left=0.5/(ncol+1), right=1-0.5/(ncol+1)) 
    
    for i in range(nrow):
        for j in range(ncol):
            im = np.random.rand(28,28)
            ax= plt.subplot(gs[i,j])
            ax.imshow(im)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
    
    plt.show()
    

    【讨论】:

    • 像魔术一样工作,感谢@ImportanceOfBeingErnest!只是想知道为什么你使用 figsize=(4, 10) 而不是 figsize=(10, 10)...后者立即带回空间。
    • 如果行数是列数的 3 倍,为什么还需要正方形?您当然可以将其设置为 (10,10),然后再次调整 leftright 参数。我对figsize=(4, 10) 的选择更多地取决于具有n 行和m 列的想法,(m+1, n) 的 figsize 可能是合适的;然后通过微调子图参数来完成其余的工作。
    • 我明白了。所以“figsize”实际上是指整体图像大小而不是子图。我自己弄糊涂了。
    • figsize 是以英寸为单位的图形大小。我已经编辑了答案,以包含一个根据列数和行数计算参数的示例。
    • 对于那些使用图像数组的人,将gen = iter(image_array)放在for循环之外,ax.imshow(next(gen, np.full((1, 1, 3), 255)))放在里面。当gen 的值用完时,它会渲染一个大小为 1x1 的白色图像
    【解决方案3】:

    按照 ImportanceOfBeingErnest 的回答,但如果您想使用 plt.subplots 及其功能:

    fig, axes = plt.subplots(
        nrow, ncol,
        gridspec_kw=dict(wspace=0.0, hspace=0.0,
                         top=1. - 0.5 / (nrow + 1), bottom=0.5 / (nrow + 1),
                         left=0.5 / (ncol + 1), right=1 - 0.5 / (ncol + 1)),
        figsize=(ncol + 1, nrow + 1),
        sharey='row', sharex='col', #  optionally
    )
    

    【讨论】: