【问题标题】:tensorflow_dataset image transform with dataset.maptensorflow 数据集图像转换与 dataset.map
【发布时间】:2019-07-03 17:07:59
【问题描述】:

我正在尝试使用 Python 中的 tesorflow_dataset 库加载 cifar100 dataset。使用.load() 加载数据后,我正在尝试使用.map() 将图像转换为设定的大小,地图内的 lambda 给了我

TypeError: () 缺少 2 个必需的位置参数: “粗标签”和“标签”

运行我的代码时。

在将标签信息保留在数据中的同时转换这些图像的最佳方法是什么?我不太确定 lambda 函数如何与数据集交互。

这是通过 tensorflow 2.0.0b1、tensorflow-datasets 1.0.2 和 Python 3.7.3 完成的

def transform_images(x_train, size):
    x_train = tf.image.resize(x_train, (size, size))
    x_train = x_train / 255
    return x_train

train_dataset = tfds.load(name="cifar100", split=tfds.Split.TRAIN)
train_dataset = train_dataset.map(lambda image, coarse_label, label: 
        (dataset.transform_images(image, FLAGS.size), coarse_label, label))

【问题讨论】:

    标签: tensorflow tensorflow-datasets


    【解决方案1】:

    train_dataset 的每一行都是一个字典,而不是一个元组。所以你不能像lambda image, coarse_label, label那样使用lambda

    import tensorflow as tf
    import tensorflow_datasets as tfds
    
    train_dataset = tfds.load(name="cifar100", split=tfds.Split.TRAIN)
    print(train_dataset.output_shapes)
    
    # {'image': TensorShape([32, 32, 3]), 'label': TensorShape([]), 'coarse_label': TensorShape([])}
    

    你应该像下面这样使用它:

    def transform_images(row, size):
        x_train = tf.image.resize(row['image'], (size, size))
        x_train = x_train  / 255
        return x_train, row['coarse_label'], row['label']
    
    train_dataset = train_dataset.map(lambda row:transform_images(row, 16))
    print(train_dataset.output_shapes)
    
    # (TensorShape([16, 16, 3]), TensorShape([]), TensorShape([]))
    

    【讨论】:

    • 有效!先生,您是救命稻草。那么为什么我可以访问函数内部的row['image'],但如果我尝试执行train_dataset['image'],我会得到:TypeError: '_OptionsDataset' object is not subscriptable
    • @KennethWitham train_dataset 的每一行都是字典,而不是 train_dataset 是字典。
    猜你喜欢
    • 2020-09-23
    • 2017-03-09
    • 2022-07-11
    • 2020-05-31
    • 1970-01-01
    • 2016-11-13
    • 2018-07-21
    • 2018-07-22
    • 1970-01-01
    相关资源
    最近更新 更多