【发布时间】:2023-04-04 20:17:01
【问题描述】:
下面的生成器函数太慢了。有没有办法优化这段代码? train_dataset_c1 是图像类型 1 的训练数据集,1 train_dataset_c0 是图像类型 0 的训练数据集,0
def generator(positive_dataset, negative_dataset):
while True:
for pos_rec, neg_rec in zip(positive_dataset, negative_dataset):
pos_x, pos_y = pos_rec
neg_x, neg_y = neg_rec
x = tf.concat([pos_x, neg_x], axis=0)
y = tf.concat([pos_y, neg_y], axis=0)
yield x, y
train_generator = generator(train_dataset_c1, train_dataset_c0)
test_generator = generator(test_dataset_c1, test_dataset_c0)
【问题讨论】:
-
太慢是什么意思?
-
嘿@MatiasValdenegro,当我尝试从 test_generator 打印前几张图像时,它需要相当长的时间。我也很好奇是什么让 test_generator 比 train_generator 更慢。顺便说一句,我使用 dataset.skip(n) 来创建测试数据集和 dataset.take(n) 来创建火车。
标签: deep-learning generator tensorflow2.0