【问题标题】:Keras: fit_generator NotImplementedErrorKeras:fit_generator NotImplementedError
【发布时间】:2020-06-03 17:23:41
【问题描述】:

我在使用 model.fit_generator 时遇到问题,它抛出 NotImplementedError,我不知道问题出在哪里。在较旧的 Keras 和 TF 下它可以工作,但多年后我尝试更新到新版本并出现问题。

当我使用时:

model.fit_generator(generator=generator_train,
                        steps_per_epoch=generator_train.n / batch_size,
                        epochs=20,
                        verbose=1,
                        validation_data=generator_val,
                        validation_steps=math.ceil(generator_val.n / batch_size),
                        callbacks=[tb_callback, saver_callback],
                        use_multiprocessing=False,
                        initial_epoch=0
                        )

I got this error

我的发电机:

import cv2
import numpy as np
from keras.preprocessing.image import Iterator
from boxcars_image_transformations import alter_HSV, image_drop, unpack_3DBB, add_bb_noise_flip
import random

#%%

class BoxCarsDataGenerator(Iterator):
    def __init__(self, dataset, part, batch_size=8, training_mode=False, seed=None, generate_y = True, image_size = (224,224)):
        assert image_size == (224,224), "only images 224x224 are supported by unpack_3DBB for now, if necessary it can be changed"
        assert dataset.X[part] is not None, "load some classification split first"
        super().__init__(dataset.X[part].shape[0], batch_size, training_mode, seed)
        self.part = part
        self.generate_y = generate_y
        self.dataset = dataset
        self.image_size = image_size
        self.training_mode = training_mode
        if self.dataset.atlas is None:
            self.dataset.load_atlas()
        print("ANOOO TU SOM")

    #%%
    def __next__(self):
        with self.lock:
            index_array, current_index, current_batch_size = next(self.index_generator)
        x = np.empty([current_batch_size] + list(self.image_size) + [3], dtype=np.float32)
        for i, ind in enumerate(index_array):
            vehicle_id, instance_id = self.dataset.X[self.part][ind]
            vehicle, instance, bb3d = self.dataset.get_vehicle_instance_data(vehicle_id, instance_id)
            image = self.dataset.get_image(vehicle_id, instance_id) 
            if self.training_mode:
                image = alter_HSV(image) # randomly alternate color
                image = image_drop(image) # randomly remove part of the image
                bb_noise = np.clip(np.random.randn(2) * 1.5, -5, 5) # generate random bounding box movement
                flip = bool(random.getrandbits(1)) # random flip
                image, bb3d = add_bb_noise_flip(image, bb3d, flip, bb_noise) 

            image = unpack_3DBB(image, bb3d)      
            image = (image.astype(np.float32) - 116)/128.
            x[i, ...] = image
        if not self.generate_y:
            return x
        y = self.dataset.Y[self.part][index_array]
        return x, y


【问题讨论】:

标签: python tensorflow keras generator


【解决方案1】:

编辑 2:如何解决您的问题:

你想添加:

def __getitem__(self, item):
    return self.__next__()

在你的BoxCarsDataGenerator

编辑:在更详细地查看问题后,问题来自 Keras 的迭代器类;当迭代你的BoxCarsDataGenerator__getitem__ 时返回:

self._get_batches_of_transformed_samples(index_array)

_get_batches_of_transformed_samples定义为:

def _get_batches_of_transformed_samples(self, index_array):
    """Gets a batch of transformed samples.

    # Arguments
        index_array: Array of sample indices to include in batch.

    # Returns
        A batch of transformed samples.
    """
    raise NotImplementedError

总之,你不应该使用 Keras 的 Iterator,因为它似乎还没有实现。

编辑结束

fit_generator 现在已被弃用,因为 fit 也可以处理生成器,您可以查看这篇文章:https://stackoverflow.com/a/59381897/12896974

【讨论】:

  • 我也用model.fit()试过了,但错误是一样的。
  • 您可以添加整个跟踪日志吗?
  • 好吧,我的整个跟踪日志与错误图像相同。
  • 是的,但对我来说问题出在哪里仍然没有意义。
  • 老实说,我查看了 github 对这个文件的历史,似乎这个功能从未实现过,你应该转向另一种方法来解决你的问题。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 2022-08-04
  • 2021-03-15
  • 1970-01-01
  • 2019-02-13
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多