【问题标题】:tf.data pipeline design for optimized performance优化性能的 tf.data 管道设计
【发布时间】:2018-11-08 13:42:25
【问题描述】:

我是 TensorFlow 新手,我想知道使用 tfdata 设置数据集的特定顺序。例如:

    data_files = tf.gfile.Glob("%s%s%s" % ("./data/cifar-100-binary/", self.data_key, ".bin"))
    data = tf.data.FixedLengthRecordDataset(data_files, record_bytes=3074)
    data = data.map(self.load_transform)
    if self.shuffle_key:
        data = data.shuffle(5000)

    data = data.batch(self.batch_size).repeat(100)
    iterator = data.make_one_shot_iterator()
    img, label = iterator.get_next()
    # label = tf.one_hot(label, depth=100)
    print('img_shape:', img.shape)

在这种情况下,我读取数据,然后对数据进行洗牌,然后是批量和重复规范。使用这种方法,我的电脑内存增加了 2%

然后我又尝试了一种方法:

    data_files = tf.gfile.Glob("%s%s%s" % ("./data/cifar-100-binary/", self.data_key, ".bin"))
    data = tf.data.FixedLengthRecordDataset(data_files, record_bytes=3074)
    data = data.map(self.load_transform)
    data = data.batch(self.batch_size).repeat(100)
    if self.shuffle_key:
        data = data.shuffle(5000)
    iterator = data.make_one_shot_iterator()
    img, label = iterator.get_next()
    # label = tf.one_hot(label, depth=100)
    print('img_shape:', img.shape)

所以在这种情况下,当我第一次指定批量大小时,重复然后随机播放 RAM 利用率会增加 40%(我不知道为什么),如果有人能帮我解决这个问题,那就太好了。 那么是否有一个我应该始终遵循的顺序来使用 tf.data 在 tensorflow 中定义数据集?

【问题讨论】:

标签: python tensorflow deep-learning tensorflow-datasets


【解决方案1】:

内存使用量增加,因为您正在改组批次而不是单个记录。

data.shuffle(5000) 将填充5000 元素的缓冲区,然后从缓冲区中随机采样以生成下一个元素。

data.batch(self.batch_size) 将元素类型从单个记录更改为批量记录。因此,如果您在shuffle 之前调用batch,则随机播放缓冲区将包含5000 * self.batch_size 记录,而不仅仅是5000

调用shufflebatch的顺序也会影响数据本身。在洗牌之前进行批处理将导致批处理的所有元素都是连续的。

batchshuffle 之前:

>>> dataset = tf.data.Dataset.range(12)
>>> dataset = dataset.batch(3)
>>> dataset = dataset.shuffle(4)
>>> print([element.numpy() for element in dataset])
[array([ 9, 10, 11]), array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8])]

shufflebatch 之前:

>>> dataset = tf.data.Dataset.range(12)
>>> dataset = dataset.shuffle(4)
>>> dataset = dataset.batch(3)
>>> print([element.numpy() for element in dataset])
[array([1, 2, 5]), array([4, 7, 8]), array([0, 3, 9]), array([ 6, 10, 11])]

通常在批处理之前进行洗牌,以避免批处理中的所有元素都是连续的。

【讨论】:

    猜你喜欢
    • 2012-09-07
    • 2021-11-25
    • 1970-01-01
    • 2019-02-24
    • 2013-06-25
    • 2020-10-08
    • 1970-01-01
    • 1970-01-01
    • 2013-03-17
    相关资源
    最近更新 更多