您可以将ImageDataGenerator 类与您喜欢的任何类型的标签一起使用,也可以是图像,它们只是多维数组。这是一个使用虚拟 numpy 图像的示例:
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
# Create fake images
n, width, height = 1000, 28, 28
images_data = np.random.randint(low=0, high=256, size=(n, height, width, 3))
images_labels = np.random.randint(low=0, high=256, size=(n, height, width, 3))
image_gen = ImageDataGenerator()
batch_size = 100
batch_gen = image_gen.flow(images_data, images_labels, batch_size=batch_size)
然后,您可以将batch_gen 传递给fit_generator,例如,它将产生(images_data, images_labels) 的元组,两者都具有(batch_size, height, width, 3) 的形状。您可以通过以下方式进行检查:
batch = batch_gen.next()
print(len(batch))
print(batch[0].shape)
print(batch[1].shape)
如果您的数据集不适合内存并存储为文件,您也可以使用flow_from_directory。 Keras 官方文档中有examples。
如果您编写自己的批处理生成器函数,您还可以在生成批处理之前对它们进行一些操作:
def _generate_batches(image_gen, images_data, images_labels, batch_size):
for batch in image_gen.flow(images_data, images_labels,
batch_size=batch_size):
# Here you can do whatever you like to your batch
yield (batch[0], batch[1])
最后,如果您需要 ImageDataGenerator 的特定功能,您可以随时构建您的自定义 ImageDataGenerator 类:
class ImageDataGeneratorCustom(ImageDataGenerator):
...
特别是,您可能想要覆盖flow() 函数,甚至构建自定义Iterator。