【问题标题】:How to use tensorflow2.0 dataset with keras ImageDataGenerator如何在 keras ImageDataGenerator 中使用 tensorflow2.0 数据集
【发布时间】:2019-04-10 15:14:06
【问题描述】:

我正在使用 tensorflow 2.0 API,其中我从所有图像路径创建了一个数据集,如下例所示

X_train, X_test, y_train, y_test = train_test_split(all_image_paths, all_image_labels, test_size=0.20, random_state=32)

path_train_ds = tf.data.Dataset.from_tensor_slices(X_train)
image_train_ds = path_train_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)

但是,当我运行此代码以使用 keras ImageDataGenerator 应用一些参数时出现错误

datagen=tf.keras.preprocessing.image.ImageDataGenerator(featurewise_center=True,
        featurewise_std_normalization=True,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True)
datagen.fit(image_train_ds)

错误:

 /usr/local/lib/python3.6/dist-packages/keras_preprocessing/image/image_data_generator.py in fit(self, x, augment, rounds, seed)
    907             seed: Int (default: None). Random seed.
    908        """
--> 909         x = np.asarray(x, dtype=self.dtype)
    910         if x.ndim != 4:
    911             raise ValueError('Input to `.fit()` should have rank 4. '

/usr/local/lib/python3.6/dist-packages/numpy/core/numeric.py in asarray(a, dtype, order)
    499 
    500     """
--> 501     return array(a, dtype, copy=False, order=order)
    502 
    503 

TypeError: float() argument must be a string or a number, not 'ParallelMapDataset'

【问题讨论】:

    标签: keras tensorflow2.0


    【解决方案1】:

    tf.keras.preprocessing.image.ImageDataGenerator 不适用于tf.data.Dataset 对象,它被设计用于处理普通的旧图像。

    如果您想应用扩充,您必须使用tf.data.Dataset 对象本身(通过各种.map 调用),或者您可以在使用tf.keras.preprocessing.image.ImageDataGenerator 创建扩充数据集后创建tf.data.Dataset 对象。

    【讨论】:

    • 如果属实,这是一个令人难以置信的愚蠢设计选择。
    猜你喜欢
    • 2019-12-07
    • 2020-09-05
    • 2019-11-27
    • 1970-01-01
    • 2019-02-08
    • 2017-06-18
    • 1970-01-01
    • 1970-01-01
    • 2018-12-02
    相关资源
    最近更新 更多