【问题标题】:keras model.fit_generator() several times slower than model.fit()keras model.fit_generator() 比 model.fit() 慢几倍
【发布时间】:2017-03-07 06:40:28
【问题描述】:

即使从 Keras 1.2.2 开始,引用 merge,它确实包含多处理,但由于磁盘读取速度限制,model.fit_generator() 仍然比 model.fit() 慢 4-5 倍。如何加快速度,比如通过额外的多处理?

【问题讨论】:

  • 这取决于瓶颈在哪里......如果是阅读速度限制,请增加batch_size以减慢您的训练步骤并增加队列大小和worker的nb。你是在 GPU 还是 CPU 上训练?
  • 如果您提供有关您的数据、批量大小、加载类型等的详细信息,那就太好了。
  • 训练在 GPU 上。我已将批量大小从 32、64 更改为 128,速度没有显着差异。
  • 按设计它应该更慢。 fit_generator 中存在大量与 I/O 相关的开销,而 fit() 中不存在这些开销。 SSD 可能是缓解这种情况的方法。

标签: python machine-learning tensorflow neural-network keras


【解决方案1】:

您可能想查看documentationfit_generator()workersmax_queue_size 参数。本质上,更多的workers 会创建更多的线程来将数据加载到向您的网络提供数据的队列中。但是,填充队列可能会导致内存问题,因此您可能需要减少 max_queue_size 以避免这种情况。

【讨论】:

    【解决方案2】:

    我有一个类似的问题,我切换到 dask 将数据加载到内存中,而不是使用我使用 pandas 的生成器。因此,根据您的数据大小,如果可能,将数据加载到内存中并使用 fit 函数。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-02-07
      • 1970-01-01
      • 2021-02-13
      • 2020-06-20
      • 2017-08-26
      • 2011-10-07
      • 2021-04-12
      • 1970-01-01
      相关资源
      最近更新 更多