【发布时间】:2020-08-08 13:40:17
【问题描述】:
我使用一个大型图像数据集,我在第一步将其转换为 tfrecords,然后在下一步加载到 tf.data.dataset。
但是数据集太大了,尽管有 12 GB 的 GPU,但我无法获得比 10 更大的批量大小。现在问题来了,如何优化图像的加载,以便达到更大的 batch_size。
有没有办法使用 .fit_generator() 来优化这个过程?
这是我当前加载训练数据的过程(验证数据以相同的方式转换,因此这里也没有显示):
train_dataset = dataset.load_tfrecord_dataset(dataset_path, class_names_path, image_size)
train_dataset = train_dataset.shuffle(buffer_size=shuffle_buffer)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.map(lambda x, y: (
dataset.transform_images(x, image_size),
dataset.transform_targets(y, anchors, anchor_masks, image_size)))
train_dataset = train_dataset.prefetch(batch_size)
我的培训阶段开始:
history = model.fit(train_dataset,
epochs=epochs,
callbacks=callbacks,
validation_data=val_dataset)
【问题讨论】:
标签: python-3.x tensorflow deep-learning tensorflow2.0 tensorflow-datasets