【问题标题】:Mapping each sample in a tf.dataset to an id将 tf.dataset 中的每个样本映射到一个 id
【发布时间】:2020-02-23 19:10:29
【问题描述】:

出于测试目的,我想为我的 tf.dataset 中的每个样本附加一个 ID。简单地向上计数就足够了。

我的数据集是 FlatMapDataset fwiw 类型。

for entry in img_ds:
        print(entry.shape)

(128, 128, 3)
(128, 128, 3)
(128, 128, 3)
(128, 128, 3)
...

我尝试的是有一个映射函数,它在其中定义一个计数器并向上计数:

@staticmethod
    def map_to_id(img):
        try:
            ExperimentalPipeline.map_to_id.id_counter += 1
        except AttributeError:
            ExperimentalPipeline.map_to_id.id_counter = 0
        return img, ExperimentalPipeline.map_to_id.id_counter

然后使用来自 tf.data 的Dataset.map 为每个样本附加一个 id:

img_ds = img_ds.map(ExperimentalPipeline.map_to_id)

不幸的是,这不起作用,每个样本的 id 为零:

for i, id in img_ds:
        print(f"{i.shape}, {id}")

(128, 128, 3), 0
(128, 128, 3), 0
(128, 128, 3), 0
(128, 128, 3), 0
...

我还注意到我的map_to_id 函数只被调用了一次。

@staticmethod
def map_to_id(img):
    print("enter map_to_id")
    try:
        ExperimentalPipeline.map_to_id.id_counter += 1
    except AttributeError:
        print("caught exception")
        ExperimentalPipeline.map_to_id.id_counter = np.random.randint(1000)
    return img, ExperimentalPipeline.map_to_id.id_counter

输入 map_to_id
捕获异常
(128, 128, 3), 889
(128, 128, 3), 889
(128, 128, 3), 889
(128, 128, 3), 889

我想我不明白Dataset.map 应该如何工作。我虽然会获取数据集中被调用的每个样本,并以样本作为参数调用提供的函数。
有人可以帮我解决这个问题吗?

【问题讨论】:

    标签: python tensorflow tensorflow2.0 tensorflow-datasets


    【解决方案1】:

    TensorFlow 将运行一次map 函数以将该函数编译为 TensorFlow 操作。然后这些操作,而不是原始的 python 函数,将应用于数据集的每个元素。如果你想为每个元素运行你原来的python函数,你可以使用py_function来代替。

    在您想要附加元素 ID 的特定情况下,您可以使用 Dataset.enumerate 来实现您的目标:

    img_ds = img_ds.enumerate()
    

    【讨论】:

    • 啊不错Dataset.enumerate 看起来很方便,明天去研究一下。谢谢!
    【解决方案2】:

    好的,在阅读了 tensorflow 文档之后,我发现了这个:

    请注意,无论定义 map_func 的上下文如何 (eager vs. graph),tf.data 跟踪函数并将其作为 图形。要在函数内部使用 Python 代码,您有两种选择:

    1) 依靠 AutoGraph 将 Python 代码转换为等价图 计算。这种方法的缺点是 AutoGraph 可以 转换部分但不是全部 Python 代码。

    2) 使用 tf.py_function,它允许你编写任意 Python 代码 但通常会导致比 1) 更差的性能

    所以map_to_id 函数确实只被跟踪了一次。
    由于选项 1) 似乎不起作用,我只选择选项 2)。我只需要一些单元测试的 id,所以性能应该不是问题。

    解决方案如下:

    img_ds = img_ds.map(
        lambda img: tf.py_function(
            func=ExperimentalPipeline.map_to_id, inp=[img], Tout=(tf.float32, tf.int32)
        )
    )
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-09-01
      • 1970-01-01
      • 1970-01-01
      • 2017-08-10
      • 1970-01-01
      • 1970-01-01
      • 2021-11-24
      • 2022-01-01
      相关资源
      最近更新 更多