【问题标题】:Custom Keras Data Generator with yield具有产量的自定义 Keras 数据生成器
【发布时间】:2019-09-28 11:00:45
【问题描述】:

我正在尝试创建一个自定义数据生成器,但不知道如何将yield 函数与__getitem__ 方法内的无限循环相结合。

编辑:回答后我意识到我使用的代码是Sequence,不需要yield 语句。

目前我使用return 语句返回多个图像:

class DataGenerator(tensorflow.keras.utils.Sequence):
    def __init__(self, files, labels, batch_size=32, shuffle=True, random_state=42):
        'Initialization'
        self.files = files
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.random_state = random_state
        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.files) / self.batch_size))

    def __getitem__(self, index):
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]

        files_batch = [self.files[k] for k in indexes]
        y = [self.labels[k] for k in indexes]

        # Generate data
        x = self.__data_generation(files_batch)

        return x, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.files))
        if self.shuffle == True:
            np.random.seed(self.random_state)
            np.random.shuffle(self.indexes)


    def __data_generation(self, files):
        imgs = []

        for img_file in files:

            img = cv2.imread(img_file, -1)

            ###############
            # Augment image
            ###############

            imgs.append(img) 

        return imgs

在这个article 中,我看到yield 用于无限循环。我不太明白这种语法。循环是如何逃脱的?

【问题讨论】:

标签: python tensorflow keras yield


【解决方案1】:

您正在使用序列 API,它的工作方式与普通生成器有点不同。在生成器函数中,您将使用 yield 关键字在 while True: 循环内执行迭代,因此每次 Keras 调用生成器时,它都会获取一批数据并自动环绕数据的末尾。

但是在序列中,__getitem__ 函数有一个index 参数,因此不需要迭代或yield,这是由 Keras 为您执行的。这样做是为了让序列可以使用多处理并行运行,这对于旧的生成器函数是不可能的。

所以你做事是正确的,不需要改变。

【讨论】:

【解决方案2】:

Keras中的生成器示例:

def datagenerator(images, labels, batchsize, mode="train"):
    while True:
        start = 0
        end = batchsize

        while start  < len(images): 
            # load your images from numpy arrays or read from directory
            x = images[start:end] 
            y = labels[start:end]
            yield x, y

            start += batchsize
            end += batchsize

Keras 希望您在生成器中运行无限循环。

如果你想了解 Python 生成器,那么 cmets 中的链接实际上是一个很好的起点。

【讨论】:

  • 但是我怎样才能把它包含在我的课堂上呢?因为我正在使用 __get_item
  • 你不需要__get_item
  • @Anakin 这不是真的,一个序列需要一个 getitem
  • 我的意思是你不需要序列,你也可以使用数据生成器。该问题甚至提到了 Keras 数据生成器。
  • @Anakin 注意keras 不再需要使用.fit_generator 方法。您可以在.fit 方法中直接传递上面提到的生成器。有关详细信息,请参阅the documentation
猜你喜欢
  • 2018-10-23
  • 2020-12-10
  • 2019-03-16
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2019-01-29
  • 2018-07-12
  • 2017-07-07
相关资源
最近更新 更多