【问题标题】:TensorFlow dataset .map() method not working for built-in tf.keras.preprocessing.image functionsTensorFlow 数据集 .map() 方法不适用于内置 tf.keras.preprocessing.image 函数
【发布时间】:2020-10-06 10:12:44
【问题描述】:

我这样加载数据集:

import tensorflow_datasets as tfds

ds = tfds.load(
    'caltech_birds2010',
    split='train',
    as_supervised=False)

这个功能很好用:

import tensorflow as tf

@tf.function
def pad(image,label):
    return (tf.image.resize_with_pad(image,32,32),label)

ds = ds.map(pad)

但是当我尝试映射不同的内置函数时

from tf.keras.preprocessing.image import random_rotation

@tf.function
def rotate(image,label):
    return (random_rotation(image,90), label)

ds = ds.map(rotate)

我收到以下错误:

AttributeError: 'Tensor' 对象没有属性 'ndim'

这不是唯一给我带来问题的函数,不管有没有 @tf.function 装饰器都会发生这种情况。

非常感谢任何帮助!

【问题讨论】:

    标签: python tensorflow keras tensorflow-datasets


    【解决方案1】:

    我会尝试在此处使用 tf.py_function 进行随机旋转。例如:

    def rotate(image, label):
        im_shape = image.shape
        [image, label,] = tf.py_function(random_rotate,[image, label],
                                         [tf.float32, tf.string])
        image.set_shape(im_shape)
        return image, label
    
    ds = ds.map(rotate)
    

    虽然我认为他们根据What is the difference in purpose between tf.py_function and tf.function? 在这里做了类似的事情,但 tf.py_function 通过 tensorflow 执行 python 代码更直接,即使 tf.function 具有性能优势。

    【讨论】: