【问题标题】:Shuffling input files with tensorflow Datasets使用 tensorflow 数据集改组输入文件
【发布时间】:2018-05-18 21:48:01
【问题描述】:

使用旧的输入管道 API 我可以做到:

filename_queue = tf.train.string_input_producer(filenames, shuffle=True)

然后将文件名传递给其他队列,例如:

reader = tf.TFRecordReader()
_, serialized_example = reader.read_up_to(filename_queue, n)

如何使用 Dataset -API 实现类似的行为?

tf.data.TFRecordDataset() 期望文件名的张量按固定顺序排列。

【问题讨论】:

标签: python tensorflow dataset


【解决方案1】:

按顺序开始阅读,shuffle 紧随其后:

BUFFER_SIZE = 1000 # arbitrary number
# define filenames somewhere, e.g. via glob
dataset = tf.data.TFRecordDataset(filenames).shuffle(BUFFER_SIZE)

编辑:

this question 的输入管道让我了解了如何使用 Dataset API 实现文件名改组:

dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.shuffle(BUFFER_SIZE) # doesn't need to be big
dataset = dataset.flat_map(tf.data.TFRecordDataset)
dataset = dataset.map(decode_example, num_parallel_calls=5) # add your decoding logic here
# further processing of the dataset

这会将一个文件的所有数据放在下一个文件之前,依此类推。文件被打乱,但其中的数据将以相同的顺序生成。 您也可以将dataset.flat_map 替换为interleave 以同时处理多个文件并从每个文件返回样本:

dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=4)

注意: interleave 实际上并没有在多个线程中运行,它是一个循环操作。真正的并行处理见parallel_interleave

【讨论】:

  • 这会打乱我的文件还是文件中的数据?
  • 好的,但是当您有一个包含相同标签(用于深度学习)的一长串 TFRecord 文件(总共超过 50000 个示例),然后是另一个系列文件时,您会怎么做包含带有另一个标签的示例。为了使洗牌工作,您需要一个大于 50000 的缓冲区,因此需要大量的 RAM。这不是解决方案。改组文件名是一种更简单的解决方案。
  • 我并不是建议您将所有内容打包在一个大文件中,您的用例对我来说似乎很合理。我要指出的问题是,如果您只打乱文件名,您仍然会以相同的顺序读取每个文件中的数据。我同意洗牌也没有什么坏处,但是在解码样本后你仍然需要一个带有缓冲区的shuffle(),除非你可以让它们始终以相同的顺序排列。
  • @Pekka 我认为编辑可能是您的目标
  • 很高兴它很有帮助,请记住,这段代码是为 TF 1.4 编写的(我认为,或者接近那个),从那时起数据集 API 发生了巨大的变化,所以有些东西可以在今天更有效的方式:)
【解决方案2】:

当前的 Tensorflow 版本(02/2018 中的 v1.5)似乎不支持在 Dataset API 中本地进行文件名改组。这是一个使用 numpy 的简单解决方法:

import numpy as np
import tensorflow as tf

myShuffledFileList = np.random.choice(myInputFileList, size=len(myInputFileList), replace=False).tolist()

dataset = tf.data.TFRecordDataset(myShuffledFileList)

【讨论】:

  • 动态加载文件列表:tf.data.Dataset.list_files('pattern-here').shuffle(BUFFER_SIZE)。对其进行硬编码:tf.data.Dataset.from_tensor_slices([filenames]).shuffle(BUFFER_SIZE)。两者都必须后跟适当的.map,并带有打开和读取文件中记录的解码功能。同样,当前的 API 怎么可能做到这一点?另外,如果你真的想使用numpynp.random.shuffle(myInputFileList)
猜你喜欢
  • 2018-08-17
  • 2020-12-07
  • 1970-01-01
  • 2018-08-16
  • 1970-01-01
  • 2022-01-15
  • 1970-01-01
  • 1970-01-01
  • 2018-07-31
相关资源
最近更新 更多