【问题标题】:How to convert this function into a generator?如何将此函数转换为生成器?
【发布时间】:2019-11-29 12:39:01
【问题描述】:

我有一个函数(patch_generator),它接受两个图像(作为 numpy 数组)并生成补丁,给出一个 window_size。该功能正常工作,但它同时生成所有补丁。我想批量生成图像补丁。

def patch_generator(x, y, window_size):
    # x and y are numpy arrays with shape of: (bands, height, width)
    # generates image patch of shape (patches, size, size, bands)
    # and a ground truth patch of shape (patches,)
    index_patch = 0
    x = x.reshape((x.shape[1], x.shape[2], x.shape[0]))
    y = y.reshape((y.shape[1], y.shape[2], y.shape[0]))
    if window_size % 2 == 0:
        margin = int(window_size/2)
    else:
        margin = int((window_size - 1) / 2)
    x_zeros = pad_zeros(x, margin)
    x_patch = np.zeros((x.shape[0]*x.shape[1], window_size, window_size, x.shape[2]))
    y_patch = np.zeros((x.shape[0]*x.shape[1]))
    row_range = range(margin, x_zeros.shape[0] - margin)
    col_range = range(margin, x_zeros.shape[1] - margin)
    for r in product(row_range, col_range):
        if window_size % 2 == 0:
            patch = x_zeros[r[0] - margin:r[0] + margin,
                            r[1] - margin:r[1] + margin]
        else:
            patch = x_zeros[r[0] - margin:r[0]+margin+1,
                            r[1] - margin:r[1]+margin+1]
        x_patch[index_patch, :, :, :] = patch
        y_patch[index_patch] = y[r[0]-margin, r[1]-margin]
        index_patch += 1
    del margin, x_zeros, row_range, col_range, patch
    return x_patch.astype(np.float16), y_patch.astype(np.float16)

我希望通过包含一个额外的函数参数“batch_size”将相同的函数转换为生成器。这应该每次都会生成 batch_size 图像块。

def pad_zeros(x, margin=2):
    # x is a numpy array with shape of: (width, height, bands)
    new_x = np.zeros((x.shape[0]+2*margin, x.shape[1]+2*margin, x.shape[2]))
    new_x[margin:x.shape[0]+margin, margin:x.shape[1]+margin, :] = x
    return new_x

例如,

X = np.array([[0, 2, 3, 4, 5, 6, 7 , 8],
              [1, 1, 1, 1, 1, 1, 1 , 1],
              [2, 2, 3, 4, 0, 6, 5 , 4],
              [0, 2, 3, 4, 5, 6, 7 , 8],
              [3, 2, 1, 3, 0, 9, 1 , 0],
              [1, 1, 3, 4, 5, 7, 6 , 8],
              [0, 3, 3, 4, 5, 6, 1 , 0]).reshape(1, 7, 8)

Y = np.array([[0, 0, 0, 0, 0, 0, 0 , 0],
              [1, 1, 1, 1, 1, 1, 1 , 1],
              [0, 0, 0, 0, 0, 0, 0 , 0],
              [0, 0, 0, 0, 0, 0, 0 , 0],
              [0, 0, 1, 0, 0, 0, 1 , 0],
              [1, 1, 0, 0, 0, 0, 0 , 0],
              [0, 0, 0, 0, 0, 0, 1 , 0]).reshape(1, 7, 8)

【问题讨论】:

    标签: python image-processing


    【解决方案1】:

    也许是这样的?基本上要制作一个生成器,你需要使用关键字yield 而不是return。有关示例,请参阅此 blog

    def patch_generator(x, y, window_size, batch_size=64):
        # x and y are numpy arrays with shape of: (bands, height, width)
        # generates image patch of shape (patches, size, size, bands)
        # and a ground truth patch of shape (patches,)
    
        x = x.reshape((x.shape[1], x.shape[2], x.shape[0]))
        y = y.reshape((y.shape[1], y.shape[2], y.shape[0]))
        if window_size % 2 == 0:
            margin = int(window_size/2)
        else:
            margin = int((window_size - 1) / 2)
        x_zeros = pad_zeros(x, margin)
    
        row_range = range(margin, x_zeros.shape[0] - margin)
        col_range = range(margin, x_zeros.shape[1] - margin)
    
        # Prepare vectors with batch_size entries
        x_patch = np.zeros((batch_size, window_size, window_size, x.shape[2]))
        y_patch = np.zeros(batch_size)
    
        for index_patch,r in enumerate(product(row_range, col_range)):
    
            if window_size % 2 == 0:
                patch = x_zeros[r[0] - margin:r[0] + margin,
                                r[1] - margin:r[1] + margin]
            else:
                patch = x_zeros[r[0] - margin:r[0]+margin+1,
                                r[1] - margin:r[1]+margin+1]
    
            # Overwrite x_patch and y_patch at each new batch
            x_patch[index_patch % batch_size, :, :, :] = patch
            y_patch[index_patch % batch_size] = y[r[0]-margin, r[1]-margin]
    
            if ((index_patch+1) % batch_size == 0) or (index_patch==len(row_range)*len(col_range)-1):
                # When we reach batch_size, return the result
                yield (x_patch.astype(np.float16),
                       y_patch.astype(np.float16))
    

    对于测试(patch_fn 是您的原始函数,patch_generator 是修改后的生成器版本):

    x = np.array([[0, 2, 3, 4, 5, 6, 7 , 8],
                  [1, 1, 1, 1, 1, 1, 1 , 1],
                  [2, 2, 3, 4, 0, 6, 5 , 4],
                  [0, 2, 3, 4, 5, 6, 7 , 8],
                  [3, 2, 1, 3, 0, 9, 1 , 0],
                  [1, 1, 3, 4, 5, 7, 6 , 8],
                  [0, 3, 3, 4, 5, 6, 1 , 0]]).reshape(1, 7, 8)
    
    y = np.array([[0, 0, 0, 0, 0, 0, 0 , 0],
                  [1, 1, 1, 1, 1, 1, 1 , 1],
                  [0, 0, 0, 0, 0, 0, 0 , 0],
                  [0, 0, 0, 0, 0, 0, 0 , 0],
                  [0, 0, 1, 0, 0, 0, 1 , 0],
                  [1, 1, 0, 0, 0, 0, 0 , 0],
                  [0, 0, 0, 0, 0, 0, 1 , 0]]).reshape(1, 7, 8)
    
    window_size = 3
    batch_size = 2
    
    res_w_fn = patch_fn(x,y,window_size)
    res_w_gen = patch_generator(x,y,window_size,batch_size)
    
    for i,(x,y) in enumerate(res_w_gen):
        x_t,y_t = (res_w_fn[0][i*batch_size:(i+1)*batch_size,:,:,:],
                   res_w_fn[1][i*batch_size:(i+1)*batch_size])
        if np.sum(x!=x_t)==0 and np.sum(y!=y_t)==0:
            print("Batch #{} - Success".format(i))
        else:
            print("Batch #{} - Fail".format(i))
    

    【讨论】:

    • 感谢 Nakor,它有效。让我在不同的情况下检查一下,希望不会发现任何问题。
    • 最后不会生成剩余的补丁。当我设置 batch_size = 1 时,它会正确生成所有补丁。但是当我选择示例 3 时,最后一个补丁被忽略了。我想生成所有的补丁,包括最后一批比选定的 batch_size 小。
    • 因为组合数不是batch_size的倍数。我编辑了我的代码以考虑到这一点,让我知道它现在是否工作得更好
    • 它会生成额外的补丁。在 3 的情况下,它会产生额外的 1,并随着批量大小的增加而增加。
    • 你到底是什么意思?它生成所有正确的补丁,最后添加错误的补丁?
    猜你喜欢
    • 2018-12-21
    • 1970-01-01
    • 1970-01-01
    • 2021-01-03
    • 2023-03-24
    • 2015-05-30
    • 1970-01-01
    • 2020-06-07
    • 2017-06-01
    相关资源
    最近更新 更多