【问题标题】:Efficiently extract a patch from image and lable有效地从图像和标签中提取补丁
【发布时间】:2018-10-06 03:52:31
【问题描述】:

我有一个细分项目。我有图像和标签,其中包含分割的基本事实。图像很大,并且包含很多“空白”区域。 我想从图像和标签中剪切补丁,以便补丁中​​包含非零标签。

我需要它尽可能高效

我写了下面的代码,但是速度很慢。任何改进都将受到高度赞赏。

import numpy as np
import matplotlib.pyplot as plt
让我们创建虚拟数据
img = np.random.rand(300,200,3)
img[240:250,120:200]=0

mask = np.zeros((300,200))
mask[220:260,120:300]=0.7
mask[250:270,140:170]=0.3

f, axarr = plt.subplots(1,2, figsize = (10, 5))
axarr[0].imshow(img)
axarr[1].imshow(mask)[![enter image description here][1]][1]
plt.show()

我的低效代码:

IM_SIZE = 60     # Patch size

x_min, y_min = 0,0
x_max = img.shape[0] - IM_SIZE
y_max = img.shape[1] - IM_SIZE
xd, yd, x, y = 0,0,0,0

if (mask.max() > 0):
    xd, yd = np.where(mask>0)

    x_min = xd.min()
    y_min = yd.min()
    x_max = min(xd.max()- IM_SIZE-1, img.shape[0] - IM_SIZE-1)
    y_max = min(yd.max()- IM_SIZE-1, img.shape[1] - IM_SIZE-1)

    if (y_min >= y_max):

        y = y_max
        if (y + IM_SIZE >= img.shape[1] ): 
            print('Error')

    else:
        y = np.random.randint(y_min,y_max)

    if (x_min>=x_max):

        x = x_max
        if (x+IM_SIZE >= img.shape[0] ):
            print('Error')

    else:
        x = np.random.randint(x_min,x_max )
print(x,y)    
img = img[x:x+IM_SIZE, y:y+IM_SIZE,:]
mask = mask[x:x+IM_SIZE, y:y+IM_SIZE]

f, axarr = plt.subplots(1,2, figsize = (10, 5))
axarr[0].imshow(img)
axarr[1].imshow(mask)
plt.show()

【问题讨论】:

    标签: python performance numpy processing-efficiency


    【解决方案1】:

    line profiler 给出的结果快照如下:

    大部分时间被 mask.max() 使用(可以更改为 np.max(mask) 以加快速度)和 np.where(mask>0)。

    如果您每次都需要在不同的掩码上使用 where 函数,请查看 numexpr。或者,您可以使用joblib 来存储给定掩码的 x/y_min/max 结果,方法是并行运行许多此类案例。

    使用 numba.jit 重新排列函数会给我更好的结果:

    @jit
    def temp(mask):
        xd, yd = np.where(mask>0)
    
        x_min = np.min(xd)
        y_min = np.min(yd)
        x_max = min(np.max(xd)- IM_SIZE-1, img.shape[0] - IM_SIZE-1)
        y_max = min(np.max(yd)- IM_SIZE-1, img.shape[1] - IM_SIZE-1)
        return x_min,x_max,y_min,y_max
    
    def solver_new(img):
        IM_SIZE = 60     # Patch size
    
        x_min, y_min = 0,0
        x_max = img.shape[0] - IM_SIZE
        y_max = img.shape[1] - IM_SIZE
        xd, yd, x, y = 0,0,0,0
    
        if (np.max(mask) > 0):
            x_min,x_max,y_min,y_max = temp(mask)
            if (y_min >= y_max):
    
                y = y_max
                if (y + IM_SIZE >= img.shape[1] ): 
                    print('Error')
    
            else:
                y = np.random.randint(y_min,y_max)
    
            if (x_min>=x_max):
    
                x = x_max
                if (x+IM_SIZE >= img.shape[0] ):
                    print('Error')
    
            else:
                x = np.random.randint(x_min,x_max )
        return x,y
    

    由于图像和补丁大小很小,因此结果并没有太大意义,因为缓存对时间有很大影响。对于问题中发布的实现,我得到大约 200us 的收益,而在此处发布的实现得到大约 90us。

    【讨论】:

    • 谢谢!它大大提高了性能。 numba jit 对我来说是新的,而且 sims 非常有用。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-01-20
    • 2019-06-17
    • 2017-04-05
    • 2020-08-22
    • 2019-10-22
    • 2016-10-20
    相关资源
    最近更新 更多