【问题标题】:How to make a generator thread-safe?如何使生成器线程安全?
【发布时间】:2021-10-13 07:02:52
【问题描述】:

我有一个如下所示的生成器:

def data_generator(data_file, index_list,....):
      orig_index_list = index_list
    while True:
        x_list = list()
        y_list = list()
        if patch_shape:
            index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
                                                 patch_overlap, patch_start_offset,pred_specific=pred_specific)
        else:
            index_list = copy.copy(orig_index_list)

        while len(index_list) > 0:
            index = index_list.pop()
            add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
                     augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
                     skip_blank=skip_blank, permute=permute)
            if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
                yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model,overlap_label=overlap_label)
                x_list = list()
                y_list = list()

我的数据集大小为 55GB,并存储为 .h5 文件 (data.h5)。读取数据时非常慢。一个 epoch 需要 7000 秒,并且在 6 个 epoch 之后出现分段错误。

我想如果我设置multi_processing = Falseworkers > 1会加快读取数据的速度:

model.fit(multi_processing = False, workers = 8)

但是当我这样做时,我收到以下错误:

RuntimeError: Your generator is NOT thread-safe. Keras requires a thread-safe generator when use_multiprocessing=False, workers > 1.

有没有办法让我的生成器线程安全?或者有没有其他有效的方法来生成这些数据?

【问题讨论】:

  • 这能回答你的问题吗? Are Generators Threadsafe?。请特别查看以LockedIterator 类为特色的答案(第二个答案)。其实我觉得LockedIterator这个类是错的。
  • 不,它没有。我尝试了其他发布的解决方案,但没有任何效果。我的问题是如何使上述生成器线程安全,以便我可以设置use_multiprocessing=False, workers > 1 并检查数据加载过程的速度是否有任何改进。我的最终目标是让训练更快,所以如果有人知道任何其他有效的加载数据的方法,那就更好了。
  • 糟糕。我复制和粘贴不正确。我的回答如下。看看,让我知道你是否遵循。此外,如果您对效率有疑问,that 是另一个未来的帖子。不要捎带这样的问题。当他们提出多个问题时,帖子将被关闭。

标签: python multithreading multiprocessing generator


【解决方案1】:

我认为我在上面的评论中引用的 LockedIterator 类是不正确的,应该如下例所示:

import threading

class LockedIterator(object):
    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = iter(it)

    def __iter__(self): return self

    def __next__(self):
        with self.lock:
            return self.it.__next__()
            
def gen():
    for x in range(10):
        yield x

new_gen = LockedIterator(gen())

def worker(g):
    for x in g:
        print(x, flush=True)

t1 = threading.Thread(target=worker, args=(new_gen,))
t2 = threading.Thread(target=worker, args=(new_gen,))
t1.start()
t2.start()
t1.join()
t2.join()

打印:

0
1
23

4
5
6
7
8
9

如果您想保证打印输出每行打印一个值,那么我们还需要将threading.Lock 实例传递给每个线程并在该锁的控制下发出print 语句以便打印被序列化了。

【讨论】:

  • 我不明白如何将其适应我在问题中发布的生成器。你能再解释一下吗?
  • 您目前有 data_generator(actual parameters) 的地方使用 LockedIterator(data_generator(actual parameters))。当然,如果您发布了这个生成器是如何被引用的,那将会有所帮助。不幸的是,我对 Keras 并不熟悉。
  • 我问了另一个关于效率的问题。在那里,我提供了有关如何引用此生成器的更多详细信息:stackoverflow.com/questions/68705944/…
  • 所以你会想要 this 问题training_generator = LockedIterator(data_generator(data_file, training_list,....)) 并且你会将training_generator 传递给每个线程。
  • 我按照您的建议编辑了:training_generator = LockedIterator(data_generator(data_file, training_list,....))。然后model.fit(use_multiprocessing=False, workers =8)。它还没有给我一个错误。等待第一个 epoch 完成以检查速度是否有任何提高。
猜你喜欢
  • 2017-05-02
  • 2010-11-11
  • 2017-04-20
  • 1970-01-01
  • 1970-01-01
  • 2012-02-07
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多