【问题标题】:How to train large dataset on tensorflow 2.x如何在 tensorflow 2.x 上训练大型数据集
【发布时间】:2020-10-19 06:48:01
【问题描述】:

我有一个包含大约 200 万行和 6,000 列的大型数据集。输入的 numpy 数组 (X, y) 可以很好地保存训练数据。但是当它转到 model.fit() 时,我得到一个 GPU Out-Of-Memory 错误。我正在使用张量流 2.2。根据其手册,model.fit_generator 已被弃用,而 model.fit 是首选。

有人可以概述使用 tensorflow v2.2 训练大型数据集的步骤吗?

【问题讨论】:

    标签: tensorflow tensorflow2.x


    【解决方案1】:

    最好的解决方案是使用tf.data.Dataset(),这样您就可以轻松地使用.batch() 方法对数据进行批处理。

    这里有很多教程,你可能想使用from_tensor_slices()直接玩numpy数组。

    下面有两个很好的文档可以满足您的需求。

    https://www.tensorflow.org/tutorials/load_data/numpy

    https://www.tensorflow.org/guide/data

    【讨论】:

    • 感谢您的指点。因此,在我的示例中,如果我想一次训练 200,000 行,我可以设置 shuffle_buffer_size = 2,000,000,然后设置 batch_size=200,000?
    • 是的,这就是它的工作方式。当然,batch_size 小得多的值也是可行的,比如 32、64 等。现在 shuffle_buffer_size = 2,000,000 可能需要很长时间,你可以选择一个更小的值。
    • 如果我的回答帮助您接受/支持它。谢谢。
    • 谢谢,卡林。如果我正确理解文档,这里涉及两个 batch_size。请让我知道我是否正确。有原始的 batch_size 用于选择小批量梯度计算中的行数。 32 是一个很好的数字。然后还有另一个与数据集关联的 batch_size。这是决定一次在 TF 内存中选择的行数。这个 batch_size 可以与 GPU 内存可以容纳的一样大。这是正确的吗?
    • 在这种情况下,您不需要在 model.fit() 中传递 batch_size 参数。它将自动使用您在 tf.data.Dataset().batch() 中使用的 BATCH_SIZE。至于你的另一个问题:批量大小超参数确实需要仔细调整。另一方面,如果你看到 OOM 错误,你应该减少它直到你没有得到 OOM(通常以这种方式 32 --> 16 --> 8 ...)。在您的情况下,我将从 2 的 batch_size 开始,将其增加 2 的幂,然后检查我是否仍然得到 OOM。
    猜你喜欢
    • 1970-01-01
    • 2020-09-23
    • 2021-05-13
    • 2021-08-20
    • 1970-01-01
    • 2017-12-11
    • 1970-01-01
    • 2021-11-23
    • 1970-01-01
    相关资源
    最近更新 更多