【发布时间】:2020-02-23 19:10:29
【问题描述】:
出于测试目的,我想为我的 tf.dataset 中的每个样本附加一个 ID。简单地向上计数就足够了。
我的数据集是 FlatMapDataset fwiw 类型。
for entry in img_ds:
print(entry.shape)
(128, 128, 3)
(128, 128, 3)
(128, 128, 3)
(128, 128, 3)
...
我尝试的是有一个映射函数,它在其中定义一个计数器并向上计数:
@staticmethod
def map_to_id(img):
try:
ExperimentalPipeline.map_to_id.id_counter += 1
except AttributeError:
ExperimentalPipeline.map_to_id.id_counter = 0
return img, ExperimentalPipeline.map_to_id.id_counter
然后使用来自 tf.data 的Dataset.map 为每个样本附加一个 id:
img_ds = img_ds.map(ExperimentalPipeline.map_to_id)
不幸的是,这不起作用,每个样本的 id 为零:
for i, id in img_ds:
print(f"{i.shape}, {id}")
(128, 128, 3), 0
(128, 128, 3), 0
(128, 128, 3), 0
(128, 128, 3), 0
...
我还注意到我的map_to_id 函数只被调用了一次。
@staticmethod
def map_to_id(img):
print("enter map_to_id")
try:
ExperimentalPipeline.map_to_id.id_counter += 1
except AttributeError:
print("caught exception")
ExperimentalPipeline.map_to_id.id_counter = np.random.randint(1000)
return img, ExperimentalPipeline.map_to_id.id_counter
输入 map_to_id
捕获异常
(128, 128, 3), 889
(128, 128, 3), 889
(128, 128, 3), 889
(128, 128, 3), 889
我想我不明白Dataset.map 应该如何工作。我虽然会获取数据集中被调用的每个样本,并以样本作为参数调用提供的函数。
有人可以帮我解决这个问题吗?
【问题讨论】:
标签: python tensorflow tensorflow2.0 tensorflow-datasets