【问题标题】:Keras predict_generator corrupted imagesKeras predict_generator 损坏的图像
【发布时间】:2020-10-13 14:40:56
【问题描述】:

我正在尝试使用我训练有素的模型预测数百万张图像,该模型使用 python 3 中的 predict_generator 以 keras 和 tensorflow 作为后端。生成器和模型预测工作,但是,目录中的某些图像损坏或损坏并导致 predict_generator 停止并引发错误。删除图像后,它会再次工作,直到下一个损坏/损坏的图像通过该功能。

由于图像太多,运行脚本来打开每个图像并删除引发错误的图像是不可行的。有没有办法将“如果损坏则跳过图像”参数合并到生成器中或从目录函数中流出?
非常感谢任何帮助!

【问题讨论】:

  • This answer 是一种解决方法。
  • @today,您的答案可能适用于训练,但可能不适用于预测。在预测过程中,您需要将预测分数映射到原始图像文件名。因此,您需要跟踪哪些图像已损坏。我特意在下面给出答案。

标签: python keras


【解决方案1】:

ImageDataGeneratorflow_from_directory 方法中都没有这样的参数,正如您在 Keras 文档中看到的那样(herehere)。一种解决方法是扩展ImageDataGenerator 类并重载flow_from_directory 方法,以检查图像是否已损坏,然后再将其放入生成器中。 Here你可以找到它的源代码。

【讨论】:

  • 源代码链接失效。这里是更新的link:
【解决方案2】:

由于它发生在预测期间,如果您跳过任何图像或批次,您需要跟踪哪些图像被跳过,以便您可以正确地将预测分数映射到图像文件名。

基于这个想法,我的 DataGenerator 是用一个有效的图像索引跟踪器实现的。特别要关注变量valid_index,在该变量中跟踪有效图像的索引。

class DataGenerator(keras.utils.Sequence):
    def __init__(self, df, batch_size, verbose=False, **kwargs):
        self.verbose = verbose
        self.df = df
        self.batch_size = batch_size
        self.valid_index = kwargs['valid_index']
        self.success_count = self.total_count = 0

    def __len__(self):
        return int(np.ceil(self.df.shape[0] / float(self.batch_size)))

    def __getitem__(self, idx):
        print('generator is loading batch ',idx)
        batch_df = self.df.iloc[idx * self.batch_size:(idx + 1) * self.batch_size]
        self.total_count += batch_df.shape[0]

        # return a list whose element is either an image array (when image is valid) or None(when image is corrupted)
        x = load_batch_image_to_arrays(batch_df['image_file_names'])

        # filter out corrupted images
        tmp = [(u, i) for u, i in zip(x, batch_df.index.values.tolist()) if
               u is not None]

        # boundary case. # all image failed, return another random batch
        if len(tmp) == 0:
            print('[ERROR] All images loading failed')
            # based on https://github.com/keras-team/keras/blob/master/keras/utils/data_utils.py#L621,
            # Keras will automatically find the next batch if it returns None
            return None

        print('successfully loaded image in {}th batch {}/{}'.format(str(idx), len(tmp), self.batch_size))
        self.success_count += len(tmp)

        x, batch_index = zip(*tmp) 
        x = np.stack(x)  # list to np.array
        self.valid_index[idx] = batch_index

        # follow preprocess input function provided by keras
        x = resnet50_preprocess(np.array(x, dtype=np.float))
        return x

    def on_epoch_end(self):
        print('total image count', self.total_count)
        print('successful images count', self.success_count)
        self.success_count = self.total_count = 0 # reset count after one epoch ends.

在预测期间。

predictions = model.predict_generator(
            generator=data_gen,
            workers=10,
            use_multiprocessing=False,
            max_queue_size=20,
            verbose=1
        ).squeeze()
indexes = []
for i in sorted(data_gen.valid_index.keys()):
    indexes.extend(data_gen.valid_index[i])
result_df = df.loc[indexes]
result_df['score'] = predictions

【讨论】:

    猜你喜欢
    • 2019-02-15
    • 2012-11-18
    • 1970-01-01
    • 1970-01-01
    • 2012-04-13
    • 1970-01-01
    • 2020-12-05
    • 2020-12-19
    • 2020-02-11
    相关资源
    最近更新 更多