【问题标题】:Matplotlib: save plot to numpy arrayMatplotlib:将绘图保存到 numpy 数组
【发布时间】:2011-12-10 21:55:50
【问题描述】:

在 Python 和 Matplotlib 中,很容易将绘图显示为弹出窗口或将绘图保存为 PNG 文件。如何将绘图保存到 RGB 格式的 numpy 数组中?

【问题讨论】:

    标签: python numpy matplotlib


    【解决方案1】:

    当您需要与保存的绘图进行逐像素比较时,这对于单元测试等来说是一个方便的技巧。

    一种方法是使用fig.canvas.tostring_rgb,然后使用numpy.fromstring 和适当的数据类型。还有其他方法,但这是我倾向于使用的方法。

    例如

    import matplotlib.pyplot as plt
    import numpy as np
    
    # Make a random plot...
    fig = plt.figure()
    fig.add_subplot(111)
    
    # If we haven't already shown or saved the plot, then we need to
    # draw the figure first...
    fig.canvas.draw()
    
    # Now we can save it to a numpy array.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    

    【讨论】:

    • 适用于 Agg,在 import matplotlib.pyplot as plt 之前添加 matplotlib.use('agg') 以使用它。
    • 对于图像,画布增加了很大的边距,所以我发现在绘图之前插入fig.tight_layout(pad=0) 很有用。
    • 对于带有线条和文字的图形,关闭抗锯齿功能也很重要。对于行plt.setp([ax.get_xticklines() + ax.get_yticklines() + ax.get_xgridlines() + ax.get_ygridlines()],antialiased=False) 和文本mpl.rcParams['text.antialiased']=False
    • @JoeKington np.fromstringsep='' 自版本 1.14 起已弃用。在以后的版本中应该替换为data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    • 万一遇到'FigureCanvasGTKAgg' object has no attribute 'renderer',记得matplotlib.use('Agg')stackoverflow.com/a/35407794/5339857
    【解决方案2】:

    有人提出这样的方法

    np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    

    当然,这段代码有效。但是,输出的 numpy 数组图像的分辨率太低了。

    我的提案代码是这个。

    import io
    import cv2
    import numpy as np
    import matplotlib.pyplot as plt
    
    # plot sin wave
    fig = plt.figure()
    ax = fig.add_subplot(111)
    
    x = np.linspace(-np.pi, np.pi)
    
    ax.set_xlim(-np.pi, np.pi)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    
    ax.plot(x, np.sin(x), label="sin")
    
    ax.legend()
    ax.set_title("sin(x)")
    
    
    # define a function which returns an image as numpy array from figure
    def get_img_from_fig(fig, dpi=180):
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi)
        buf.seek(0)
        img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
        buf.close()
        img = cv2.imdecode(img_arr, 1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
        return img
    
    # you can get a high-resolution image as numpy array!!
    plot_img_np = get_img_from_fig(fig)
    

    此代码运行良好。
    如果在 dpi 参数上设置大数字,则可以将高分辨率图像作为 numpy 数组。

    【讨论】:

    • 我建议将 import 语句与函数一起添加。
    • @AnshulRai 谢谢你的好建议!!我添加了关于导入、绘图以及如何使用该函数的代码。
    【解决方案3】:

    @JUN_NETWORKS 的回答有一个更简单的选项。除了将图形保存为png,还可以使用其他格式,例如rawrgba,并跳过cv2 解码步骤。

    换句话说,实际的 plot-to-numpy 转换归结为:

    io_buf = io.BytesIO()
    fig.savefig(io_buf, format='raw', dpi=DPI)
    io_buf.seek(0)
    img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                         newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
    io_buf.close()
    
    

    希望,这会有所帮助。

    【讨论】:

    • 我认为这个答案远远优于上面的答案:1)它产生高分辨率图像,2)不依赖于像 cv2 这样的外部包。
    • 我收到一个重塑错误“无法将大小为 3981312 的数组重塑为形状 (480,640,newaxis)”。有什么想法吗?
    • 确实这个答案正是我想要的!谢谢!
    • @FabianHertwig 我有同样的问题,这里是修复。创建无花果时,您必须将 dpi 设置为与保存时相同 fig = plt.figure(figsize=(16, 4), dpi=128) 然后 fig.savefig(io_buf, format='raw', dpi=128)
    • 这在省略dpi 参数时对我有用,即fig.savefig(io_buf, format='raw')
    【解决方案4】:

    如果有人想要一个即插即用的解决方案,而不修改任何先前的代码(获取对 pyplot 图形的引用和所有),下面的内容对我有用。只需在所有 pyplot 语句之后添加这个,即在 pyplot.show() 之前添加这个

    canvas = pyplot.gca().figure.canvas
    canvas.draw()
    data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
    image = data.reshape(canvas.get_width_height()[::-1] + (3,))
    

    【讨论】:

      【解决方案5】:

      是时候对您的解决方案进行基准测试了。

      import io
      import matplotlib
      matplotlib.use('agg')  # turn off interactive backend
      import matplotlib.pyplot as plt
      import numpy as np
      
      fig, ax = plt.subplots()
      ax.plot(range(10))
      
      
      def plot1():
          fig.canvas.draw()
          data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
          w, h = fig.canvas.get_width_height()
          im = data.reshape((int(h), int(w), -1))
      
      
      def plot2():
          with io.BytesIO() as buff:
              fig.savefig(buff, format='png')
              buff.seek(0)
              im = plt.imread(buff)
      
      
      def plot3():
          with io.BytesIO() as buff:
              fig.savefig(buff, format='raw')
              buff.seek(0)
              data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
          w, h = fig.canvas.get_width_height()
          im = data.reshape((int(h), int(w), -1))
      
      >>> %timeit plot1()
      34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
      >>> %timeit plot2()
      50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
      >>> %timeit plot3()
      16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
      

      在这种情况下,IO 原始缓冲区是将 matplotlib 图形转换为 numpy 数组的最快速度。

      补充说明:

      • 如果您无法访问该图,您可以随时从坐标区中提取它:

        fig = ax.figure

      • 如果您需要channel x height x width 格式的数组,请执行

        im = im.transpose((2, 0, 1)).

      【讨论】:

      • 为什么这么多 stackoverflowers 痴迷于基准测试?
      【解决方案6】:

      MoviePy 使将图形转换为 numpy 数组变得非常简单。它有一个名为mplfig_to_npimage() 的内置函数。你可以这样使用它:

      from moviepy.video.io.bindings import mplfig_to_npimage
      import matplotlib.pyplot as plt
      
      fig = plt.figure()  # make a figure
      numpy_fig = mplfig_to_npimage(fig)  # convert it to a numpy array
      

      【讨论】:

        【解决方案7】:

        正如 Joe Kington 所指出的,一种方法是在画布上绘图,将画布转换为字节字符串,然后将其重塑为正确的形状。

        import matplotlib.pyplot as plt
        import numpy as np
        import math
        
        plt.switch_backend('Agg')
        
        
        def canvas2rgb_array(canvas):
            """Adapted from: https://stackoverflow.com/a/21940031/959926"""
            canvas.draw()
            buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
            ncols, nrows = canvas.get_width_height()
            scale = round(math.sqrt(buf.size / 3 / nrows / ncols))
            return buf.reshape(scale * nrows, scale * ncols, 3)
        
        
        # Make a simple plot to test with
        t = np.arange(0.0, 2.0, 0.01)
        s = 1 + np.sin(2 * np.pi * t)
        fig, ax = plt.subplots()
        ax.plot(t, s)
        
        # Extract the plot as an array
        plt_array = canvas2rgb_array(fig.canvas)
        print(plt_array.shape)
        

        但是,由于 canvas.get_width_height() 在显示坐标中返回宽度和高度,因此有时会在此答案中解决缩放问题。

        【讨论】:

          猜你喜欢
          • 2023-03-15
          • 2017-09-07
          • 1970-01-01
          • 2020-07-07
          • 1970-01-01
          • 2020-10-24
          • 1970-01-01
          • 2021-08-22
          • 1970-01-01
          相关资源
          最近更新 更多