【问题标题】:Overlapped predictions on segmented image分割图像的重叠预测
【发布时间】:2019-09-26 09:01:47
【问题描述】:

症状的背景和例子

我正在使用神经网络进行超分辨率(提高图像的分辨率)。但是,由于图像可能很大,因此我需要将其分割成多个较小的图像,并在将结果重新组合在一起之前分别对每个图像进行预测。

以下是这给我的例子:

示例 1:您可以在输出图片中看到一条穿过滑雪者肩部的细微垂直线。

示例 2:一旦您开始看到它们,您会注意到细微的线条在整个图像中形成正方形(我分割图像以进行单独预测的方式的残余)。

示例 3:可以清楚地看到穿过湖面的垂直线。


问题的根源

基本上,我的网络对边缘的预测很差,我认为这是正常的,因为“周围”信息较少。


源代码

import numpy as np
import matplotlib.pyplot as plt
import skimage.io

from keras.models import load_model

from constants import verbosity, save_dir, overlap, \
    model_name, tests_path, input_width, input_height
from utils import float_im

def predict(args):
    model = load_model(save_dir + '/' + args.model)

    image = skimage.io.imread(tests_path + args.image)[:, :, :3]  # removing possible extra channels (Alpha)
    print("Image shape:", image.shape)

    predictions = []
    images = []

    crops = seq_crop(image)  # crops into multiple sub-parts the image based on 'input_' constants

    for i in range(len(crops)):  # amount of vertical crops
        for j in range(len(crops[0])):  # amount of horizontal crops
            current_image = crops[i][j]
            images.append(current_image)

    print("Moving on to predictions. Amount:", len(images))

    for p in range(len(images)):
        if p%3 == 0 and verbosity == 2:
            print("--prediction #", p)
        # Hack because GPU can only handle one image at a time
        input_img = (np.expand_dims(images[p], 0))       # Add the image to a batch where it's the only member
        predictions.append(model.predict(input_img)[0])  # returns a list of lists, one for each image in the batch

    return predictions, image, crops


def show_pred_output(input, pred):
    plt.figure(figsize=(20, 20))
    plt.suptitle("Results")

    plt.subplot(1, 2, 1)
    plt.title("Input : " + str(input.shape[1]) + "x" + str(input.shape[0]))
    plt.imshow(input, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.subplot(1, 2, 2)
    plt.title("Output : " + str(pred.shape[1]) + "x" + str(pred.shape[0]))
    plt.imshow(pred, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.show()


# adapted from  https://stackoverflow.com/a/52463034/9768291
def seq_crop(img):
    """
    To crop the whole image in a list of sub-images of the same size.
    Size comes from "input_" variables in the 'constants' (Evaluation).
    Padding with 0 the Bottom and Right image.
    :param img: input image
    :return: list of sub-images with defined size
    """
    width_shape = ceildiv(img.shape[1], input_width)
    height_shape = ceildiv(img.shape[0], input_height)
    sub_images = []  # will contain all the cropped sub-parts of the image

    for j in range(height_shape):
        horizontal = []
        for i in range(width_shape):
            horizontal.append(crop_precise(img, i*input_width, j*input_height, input_width, input_height))
        sub_images.append(horizontal)

    return sub_images


def crop_precise(img, coord_x, coord_y, width_length, height_length):
    """
    To crop a precise portion of an image.
    When trying to crop outside of the boundaries, the input to padded with zeros.
    :param img: image to crop
    :param coord_x: width coordinate (top left point)
    :param coord_y: height coordinate (top left point)
    :param width_length: width of the cropped portion starting from coord_x
    :param height_length: height of the cropped portion starting from coord_y
    :return: the cropped part of the image
    """

    tmp_img = img[coord_y:coord_y + height_length, coord_x:coord_x + width_length]

    return float_im(tmp_img)  # From [0,255] to [0.,1.]


# from  https://stackoverflow.com/a/17511341/9768291
def ceildiv(a, b):
    return -(-a // b)


# adapted from  https://stackoverflow.com/a/52733370/9768291
def reconstruct(predictions, crops):

    # unflatten predictions
    def nest(data, template):
        data = iter(data)
        return [[next(data) for _ in row] for row in template]

    if len(crops) != 0:
        predictions = nest(predictions, crops)

    H = np.cumsum([x[0].shape[0] for x in predictions])
    W = np.cumsum([x.shape[1] for x in predictions[0]])
    D = predictions[0][0]
    recon = np.empty((H[-1], W[-1], D.shape[2]), D.dtype)
    for rd, rs in zip(np.split(recon, H[:-1], 0), predictions):
        for d, s in zip(np.split(rd, W[:-1], 1), rs):
            d[...] = s
    return recon


if __name__ == '__main__':
    print("   -  ", args)

    preds, original, crops = predict(args)  # returns the predictions along with the original
    enhanced = reconstruct(preds, crops)    # reconstructs the enhanced image from predictions

    plt.imsave('output/' + args.save, enhanced, cmap=plt.cm.gray)

    show_pred_output(original, enhanced)

问题(我想要什么)

有很多明显的幼稚方法可以解决这个问题,但我相信一定有一种非常简洁的方法:如何添加一个overlap_amount 变量,它可以让我做出重叠的预测,从而丢弃每个子图像的“边缘部分”(“片段”)并将其替换为对其周围片段的预测结果(因为它们不包含“边缘预测”)?

当然,我希望尽量减少“无用”预测(要丢弃的像素)的数量。还可能值得注意的是,输入段生成的输出段大 4 倍(即,如果它是 20x20 像素的图像,您现在会得到 80x80 像素的图像作为输出)。

【问题讨论】:

  • 为什么要将 image 分割成不同的部分?那么每个部分都可以在另一个线程/进程上处理吗?也许工作量应该在网络部分进行拆分。
  • @EranW 试图通过神经网络传递整个图像以获得在我的计算机 GPU 上计算的预测结果给我一个OOM(内存不足)错误,这就是我需要将图像拆分为单独的部分,然后使用 CPU 将它们正确地合并在一起。
  • 我将从重叠方法开始(在行和列中)并尝试找到一个尽可能小的值以减少额外的推断。您仍然需要弄清楚如何混合重叠的预测(例如均值或最大值运算符)

标签: python-3.x image image-processing computer-vision conv-neural-network


【解决方案1】:

我通过将推理转移到 CPU 中解决了一个类似的问题。它要慢得多,但至少在我的情况下解决了补丁边界问题,而不是我也测试过的基于重叠 ROI 投票或丢弃的方法。

假设您使用的是 Tensorflow 后端:

from tensorflow.python import device

with device('cpu:0')
    prediction = model.predict(...)

当然,前提是您有足够的 RAM 来适应您的模型。如果不是这样,请在下面评论,我会检查我的代码中是否有可以在这里使用的东西。

【讨论】:

  • 有趣的是,我从来没有想过这是一个解决方案,但它确实是合法的。但是,我仍然希望获得面向 GPU 的解决方案。
  • 您是否碰巧拥有用于不同测试的代码导致您选择此解决方案?
  • 告诉你一些消息:尝试 CPU 的事情基本上冻结了我的电脑,我不得不重新启动它。
  • 呃,对不起,我完全忘记了这件事。您是否在崩溃前检查过它是否填满了整个 RAM?
  • 我正在打开任务管理器时,一切都冻结了。我认为假设发生了这种情况是相当安全的,而且我并不想强迫我的计算机进入另一种我必须手动重新启动它的情况。无论如何,我刚回到这个项目,并想尝试这个懒惰的解决方案,但实际上我真的想要一个分段的解决方案(我刚刚开始着手实现天真的实现)。
【解决方案2】:

通过一种天真的方法解决了这个问题。它可能会很多更好,但至少这是可行的。

过程

基本上,它获取初始图像,然后在其周围添加填充,然后将其裁剪为多个子图像,这些子图像都排列成一个数组。裁剪完成后,所有图像也与周围的相邻图像重叠。

然后,将每张图像输入网络并收集预测结果(在这种情况下,基本上是图像分辨率的 4 倍)。在重建图像时,每个预测都是单独进行的,并且它的边缘被裁剪掉(因为它包含错误)。裁剪完成后,所有预测的粘合最终没有重叠,只有来自神经网络的预测的中间部分粘在一起。

最后,周围的填充被移除。

结果

不用排队了! :D

代码

import numpy as np
import matplotlib.pyplot as plt
import skimage.io

from keras.models import load_model

from constants import verbosity, save_dir, overlap, \
    model_name, tests_path, input_width, input_height, scale_fact
from utils import float_im


def predict(args):
    """
    Super-resolution on the input image using the model.

    :param args:
    :return:
        'predictions' contains an array of every single cropped sub-image once enhanced (the outputs of the model).
        'image' is the original image, untouched.
        'crops' is the array of every single cropped sub-image that will be used as input to the model.
    """
    model = load_model(save_dir + '/' + args.model)

    image = skimage.io.imread(tests_path + args.image)[:, :, :3]  # removing possible extra channels (Alpha)
    print("Image shape:", image.shape)

    predictions = []
    images = []

    # Padding and cropping the image
    overlap_pad = (overlap, overlap)  # padding tuple
    pad_width = (overlap_pad, overlap_pad, (0, 0))  # assumes color channel as last
    padded_image = np.pad(image, pad_width, 'constant')  # padding the border
    crops = seq_crop(padded_image)  # crops into multiple sub-parts the image based on 'input_' constants

    # Arranging the divided image into a single-dimension array of sub-images
    for i in range(len(crops)):         # amount of vertical crops
        for j in range(len(crops[0])):  # amount of horizontal crops
            current_image = crops[i][j]
            images.append(current_image)

    print("Moving on to predictions. Amount:", len(images))
    upscaled_overlap = overlap * 2
    for p in range(len(images)):
        if p % 3 == 0 and verbosity == 2:
            print("--prediction #", p)

        # Hack due to some GPUs that can only handle one image at a time
        input_img = (np.expand_dims(images[p], 0))  # Add the image to a batch where it's the only member
        pred = model.predict(input_img)[0]          # returns a list of lists, one for each image in the batch

        # Cropping the useless parts of the overlapped predictions (to prevent the repeated erroneous edge-prediction)
        pred = pred[upscaled_overlap:pred.shape[0]-upscaled_overlap, upscaled_overlap:pred.shape[1]-upscaled_overlap]

        predictions.append(pred)
    return predictions, image, crops


def show_pred_output(input, pred):
    plt.figure(figsize=(20, 20))
    plt.suptitle("Results")

    plt.subplot(1, 2, 1)
    plt.title("Input : " + str(input.shape[1]) + "x" + str(input.shape[0]))
    plt.imshow(input, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.subplot(1, 2, 2)
    plt.title("Output : " + str(pred.shape[1]) + "x" + str(pred.shape[0]))
    plt.imshow(pred, cmap=plt.cm.binary).axes.get_xaxis().set_visible(False)

    plt.show()


# adapted from  https://stackoverflow.com/a/52463034/9768291
def seq_crop(img):
    """
    To crop the whole image in a list of sub-images of the same size.
    Size comes from "input_" variables in the 'constants' (Evaluation).
    Padding with 0 the Bottom and Right image.

    :param img: input image
    :return: list of sub-images with defined size (as per 'constants')
    """
    sub_images = []  # will contain all the cropped sub-parts of the image
    j, shifted_height = 0, 0
    while shifted_height < (img.shape[0] - input_height):
        horizontal = []
        shifted_height = j * (input_height - overlap)
        i, shifted_width = 0, 0
        while shifted_width < (img.shape[1] - input_width):
            shifted_width = i * (input_width - overlap)
            horizontal.append(crop_precise(img,
                                           shifted_width,
                                           shifted_height,
                                           input_width,
                                           input_height))
            i += 1
        sub_images.append(horizontal)
        j += 1

    return sub_images


def crop_precise(img, coord_x, coord_y, width_length, height_length):
    """
    To crop a precise portion of an image.
    When trying to crop outside of the boundaries, the input to padded with zeros.

    :param img: image to crop
    :param coord_x: width coordinate (top left point)
    :param coord_y: height coordinate (top left point)
    :param width_length: width of the cropped portion starting from coord_x (toward right)
    :param height_length: height of the cropped portion starting from coord_y (toward bottom)
    :return: the cropped part of the image
    """
    tmp_img = img[coord_y:coord_y + height_length, coord_x:coord_x + width_length]
    return float_im(tmp_img)  # From [0,255] to [0.,1.]


# adapted from  https://stackoverflow.com/a/52733370/9768291
def reconstruct(predictions, crops):
    """
    Used to reconstruct a whole image from an array of mini-predictions.
    The image had to be split in sub-images because the GPU's memory
    couldn't handle the prediction on a whole image.

    :param predictions: an array of upsampled images, from left to right, top to bottom.
    :param crops: 2D array of the cropped images
    :return: the reconstructed image as a whole
    """

    # unflatten predictions
    def nest(data, template):
        data = iter(data)
        return [[next(data) for _ in row] for row in template]

    if len(crops) != 0:
        predictions = nest(predictions, crops)

    # At this point "predictions" is a 3D image of the individual outputs
    H = np.cumsum([x[0].shape[0] for x in predictions])
    W = np.cumsum([x.shape[1] for x in predictions[0]])
    D = predictions[0][0]
    recon = np.empty((H[-1], W[-1], D.shape[2]), D.dtype)
    for rd, rs in zip(np.split(recon, H[:-1], 0), predictions):
        for d, s in zip(np.split(rd, W[:-1], 1), rs):
            d[...] = s

    # Removing the pad from the reconstruction
    tmp_overlap = overlap * (scale_fact - 1)  # using "-2" leaves the outer edge-prediction error
    return recon[tmp_overlap:recon.shape[0]-tmp_overlap, tmp_overlap:recon.shape[1]-tmp_overlap]


if __name__ == '__main__':
    print("   -  ", args)

    preds, original, crops = predict(args)  # returns the predictions along with the original
    enhanced = reconstruct(preds, crops)    # reconstructs the enhanced image from predictions

    # Save and display the result
    plt.imsave('output/' + args.save, enhanced, cmap=plt.cm.gray)
    show_pred_output(original, enhanced)

常量和额外位

verbosity = 2

input_width = 64

input_height = 64

overlap = 16

scale_fact = 4

def float_im(img):
    return np.divide(img, 255.)

另类

possibly better alternative 如果您遇到与我相同的问题,您可能需要考虑它;这是相同的基本理念,但更加完善和完善。

【讨论】:

    猜你喜欢
    • 2014-04-11
    • 1970-01-01
    • 2020-12-05
    • 2019-01-24
    • 2018-10-15
    • 2021-08-13
    • 1970-01-01
    • 2011-01-28
    • 2022-10-17
    相关资源
    最近更新 更多