好的,我有一个类似的问题(虽然我不是在处理医学图像)并找到了解决方案,所以我希望其他人也会觉得它有用。
- 我假设您需要一个自定义函数来批量检索图像,因为它们不会一次全部放入内存中,因为不支持 *.hdr 文件格式,并且因为现有的 keras 辅助函数不支持不要使用回归。 (我猜你正在做某种类型的分割,如果你使用的是 u-net?)
- 我还假设您需要ImageDataGenerator,因为您不想自己实现数据增强。
因此,由于 1) 您需要将 fit_generator 函数与 IDG 结合使用,唯一的问题是 ImageDataGenerator (IDG) 不支持自定义生成器。
在某些情况下,您可以将 IDG 与 fit_generator 函数一起使用:IDG flow 函数返回 NumpyArrayIterator 类型的 Iterator。你不能使用这个,因为它需要数据适合工作内存。
IDG.flow 函数的使用/工作方式是,您首先创建 IDG 对象的实例,然后调用创建并返回 NumpyArrayIterator 的流函数,该函数包含对 IDG 对象的引用。
现在一种解决方案是编写您的自定义 DataGenerator,它继承自 keras.preprocessing.image.Iterator 类并实现 _get_batches_of_transformed_samples(请参阅 here)。
然后您扩展 IDG 类并编写一个 flow_from_generator 函数,该函数返回您的自定义 DataGenerator 的一个实例。
这听起来比实际上更费力,但一定要熟悉 IDG、NumpyArrayIterator 和 Iterator 代码。
下面是这个样子:
class DataGenerator(keras.preprocessing.image.Iterator):
def__init__(self, image_data_generator, *args, **kwargs):
#init whatever you need
self.image_data_generator = image_data_generator
#call Iterator constructor:
super(DataGenerator, self).__init__(number_of_datapoints, batch_size, shuffle, shuffle_seed)
def _get_batches_of_transformed_samples(self, index_array):
''' Here you retrieve the images and apply the image augmentation,
then return the augmented image batch.
index_array is just a list that takes care of the shuffling for you (see super class),
so this function is going to be called with index_array=[1, 6, 8]
if your batch size is 3
'''
x_transformed = np.zeros((batch_size, x_img_size, y_img_size, input_channel_num), dtype_float32)
y_transformed = np.zeros((batch_size, x_img_size, y_img_size, output_channel_num), dtype_float32)
for i, j in enumerate(index_array):
x = get_input_image_from_index(j)
y = get_output_image_from_index(j)
params = self.image_data_generator.get_random_transform(self.img_shape)
x = self.image_data_generator.apply_transform(x, params)
x = self.image_data_generator.standardize(x)
x_transformed[i] = x
y = self.image_data_generator.apply_transform(y, params)
y = self.image_data_generator.standardize(y)
y_transformed[i] = y
return(x_transformed, y_transformed)
class ImageDataGeneratorExtended(keras.preprocessing.image.ImageDataGenerator):
def flow_from_generator:(self, *args, **kwargs):
return DataGenerator(self, *args, **kwargs)
好的,希望对您有所帮助。我已经使用了我自己的上述代码版本,但还没有完全测试过它(尽管它现在对我有用),所以请谨慎对待:P
对于 *.hdr 问题:看来您可以使用ImageIO
包(它supports HDR 和 DICOM 格式,虽然我从未亲自使用过那个库)。