【问题标题】:What is the use of train_on_batch() in keras?keras 中 train_on_batch() 有什么用?
【发布时间】:2018-08-12 12:34:03
【问题描述】:

train_on_batch()fit() 有何不同?什么情况下我们应该使用train_on_batch()

【问题讨论】:

标签: machine-learning deep-learning keras


【解决方案1】:

这个问题是simple answer from the primary author:

使用fit_generator,您可以使用生成器来生成验证数据 好吧。一般来说,我建议使用fit_generator,但使用 train_on_batch 也可以。这些方法的存在只是为了 在不同的用例中方便,没有“正确”的方法。

train_on_batch 允许您根据您提供的样本集合明确更新权重,而无需考虑任何固定的批量大小。在您想要的情况下,您可以使用它:在明确的样本集合上进行训练。您可以使用该方法在传统训练集的多个批次上维护自己的迭代,但允许 fitfit_generator 为您迭代批次可能更简单。

使用train_on_batch 可能会更好的一种情况是在新一批样本上更新预训练模型。假设您已经训练并部署了一个模型,并且稍后您收到了一组以前从未使用过的新训练样本。您可以使用train_on_batch 仅在这些样本上直接更新现有模型。其他方法也可以做到这一点,但在这种情况下使用train_on_batch 是相当明确的。

除了像这样的特殊情况(您有一些教学原因需要在不同的训练批次中维护自己的光标,或者对于特殊批次的某种类型的半在线培训更新),最好只始终使用fit(用于适合内存的数据)或fit_generator(用于流式传输数据作为生成器)。

【讨论】:

  • 谢谢。请在 Keras 中实现批量训练?
【解决方案2】:

train_on_batch() 让您可以更好地控制 LSTM 的状态,例如,当使用有状态 LSTM 并需要控制对 model.reset_states() 的调用时。您可能有多系列数据,并且需要在每个系列之后重置状态,您可以使用 train_on_batch() 执行此操作,但如果您使用 .fit(),那么网络将在所有系列数据上进行训练,而无需重置状态。没有对错之分,这取决于您使用的数据以及您希望网络的行为方式。

【讨论】:

  • 正是我的用例,我正在搜索这个问题,看看这样做是否有意义,因为我有一段时间试图用fit 强制它。
【解决方案3】:

如果您使用大型数据集并且没有可轻松序列化的数据(如高等级 numpy 数组)来写入 tfrecord,那么与 fit 和 fit 生成器相比,Train_on_batch 的性能也会有所提高。

在这种情况下,您可以将数组保存为 numpy 文件并在内存中加载它们的较小子集(traina.npy、trainb.npy 等),当整个集合不适合内存时。然后,您可以使用 tf.data.Dataset.from_tensor_slices,然后将 train_on_batch 与您的子数据集一起使用,然后加载另一个数据集并再次调用 train on batch,等等,现在您已经对整个数据集进行了训练,并且可以准确控制多少和什么你的数据集训练你的模型。然后,您可以使用从数据集中获取的简单循环和函数来定义自己的时期、批量大小等。

【讨论】:

  • 这个train_on_batch 对 RL 很重要,它一次训练 1 步,fit 会非常慢
【解决方案4】:

@nbro 的回答确实有帮助,只是为了添加更多场景,假设您正在训练一些 seq to seq 模型或具有一个或多个编码器的大型网络。我们可以使用 train_on_batch 创建自定义训练循环,并使用我们的部分数据直接在编码器上进行验证,而无需使用回调。为复杂的验证过程编写回调可能很困难。有几种情况我们希望批量训练。

问候, 卡西克

【讨论】:

    猜你喜欢
    • 2018-07-11
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-04-10
    • 2016-09-29
    • 2020-08-09
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多