【发布时间】:2021-10-13 07:02:52
【问题描述】:
我有一个如下所示的生成器:
def data_generator(data_file, index_list,....):
orig_index_list = index_list
while True:
x_list = list()
y_list = list()
if patch_shape:
index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
patch_overlap, patch_start_offset,pred_specific=pred_specific)
else:
index_list = copy.copy(orig_index_list)
while len(index_list) > 0:
index = index_list.pop()
add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
skip_blank=skip_blank, permute=permute)
if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model,overlap_label=overlap_label)
x_list = list()
y_list = list()
我的数据集大小为 55GB,并存储为 .h5 文件 (data.h5)。读取数据时非常慢。一个 epoch 需要 7000 秒,并且在 6 个 epoch 之后出现分段错误。
我想如果我设置multi_processing = False和workers > 1会加快读取数据的速度:
model.fit(multi_processing = False, workers = 8)
但是当我这样做时,我收到以下错误:
RuntimeError: Your generator is NOT thread-safe. Keras requires a thread-safe generator when use_multiprocessing=False, workers > 1.
有没有办法让我的生成器线程安全?或者有没有其他有效的方法来生成这些数据?
【问题讨论】:
-
这能回答你的问题吗? Are Generators Threadsafe?。请特别查看以
LockedIterator类为特色的答案(第二个答案)。其实我觉得LockedIterator这个类是错的。 -
不,它没有。我尝试了其他发布的解决方案,但没有任何效果。我的问题是如何使上述生成器线程安全,以便我可以设置
use_multiprocessing=False, workers > 1并检查数据加载过程的速度是否有任何改进。我的最终目标是让训练更快,所以如果有人知道任何其他有效的加载数据的方法,那就更好了。 -
糟糕。我复制和粘贴不正确。我的回答如下。看看,让我知道你是否遵循。此外,如果您对效率有疑问,that 是另一个未来的帖子。不要捎带这样的问题。当他们提出多个问题时,帖子将被关闭。
标签: python multithreading multiprocessing generator