【问题标题】:Why is my tensorflow generator function so slow?为什么我的张量流生成器功能这么慢?
【发布时间】:2023-04-04 20:17:01
【问题描述】:

下面的生成器函数太慢了。有没有办法优化这段代码? train_dataset_c1 是图像类型 1 的训练数据集,1 train_dataset_c0 是图像类型 0 的训练数据集,0

def generator(positive_dataset, negative_dataset):
while True:
    for pos_rec, neg_rec in zip(positive_dataset, negative_dataset):
        pos_x, pos_y = pos_rec
        neg_x, neg_y = neg_rec
        x = tf.concat([pos_x, neg_x], axis=0)
        y = tf.concat([pos_y, neg_y], axis=0)
        yield x, y

train_generator = generator(train_dataset_c1, train_dataset_c0)
test_generator = generator(test_dataset_c1, test_dataset_c0)

【问题讨论】:

  • 太慢是什么意思?
  • 嘿@MatiasValdenegro,当我尝试从 test_generator 打印前几张图像时,它需要相当长的时间。我也很好奇是什么让 test_generator 比 train_generator 更慢。顺便说一句,我使用 dataset.skip(n) 来创建测试数据集和 dataset.take(n) 来创建火车。

标签: deep-learning generator tensorflow2.0


【解决方案1】:

如果您使用的是 tensorflow 2.0,我建议您使用 tf.data API 来加速您的管道。

实际上有一个from_generator 函数可以应用于你的生成器来加速它

使用此函数将其转换为 tf.data.Dataset 对象后,您可以使用此tutorial 中的任何策略对其进行进一步优化

【讨论】:

    猜你喜欢
    • 2020-10-08
    • 1970-01-01
    • 1970-01-01
    • 2014-09-15
    • 2022-07-29
    • 2018-11-16
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多