【发布时间】:2018-08-12 12:34:03
【问题描述】:
train_on_batch() 与 fit() 有何不同?什么情况下我们应该使用train_on_batch()?
【问题讨论】:
标签: machine-learning deep-learning keras
train_on_batch() 与 fit() 有何不同?什么情况下我们应该使用train_on_batch()?
【问题讨论】:
标签: machine-learning deep-learning keras
这个问题是simple answer from the primary author:
使用
fit_generator,您可以使用生成器来生成验证数据 好吧。一般来说,我建议使用fit_generator,但使用train_on_batch也可以。这些方法的存在只是为了 在不同的用例中方便,没有“正确”的方法。
train_on_batch 允许您根据您提供的样本集合明确更新权重,而无需考虑任何固定的批量大小。在您想要的情况下,您可以使用它:在明确的样本集合上进行训练。您可以使用该方法在传统训练集的多个批次上维护自己的迭代,但允许 fit 或 fit_generator 为您迭代批次可能更简单。
使用train_on_batch 可能会更好的一种情况是在新一批样本上更新预训练模型。假设您已经训练并部署了一个模型,并且稍后您收到了一组以前从未使用过的新训练样本。您可以使用train_on_batch 仅在这些样本上直接更新现有模型。其他方法也可以做到这一点,但在这种情况下使用train_on_batch 是相当明确的。
除了像这样的特殊情况(您有一些教学原因需要在不同的训练批次中维护自己的光标,或者对于特殊批次的某种类型的半在线培训更新),最好只始终使用fit(用于适合内存的数据)或fit_generator(用于流式传输数据作为生成器)。
【讨论】:
train_on_batch() 让您可以更好地控制 LSTM 的状态,例如,当使用有状态 LSTM 并需要控制对 model.reset_states() 的调用时。您可能有多系列数据,并且需要在每个系列之后重置状态,您可以使用 train_on_batch() 执行此操作,但如果您使用 .fit(),那么网络将在所有系列数据上进行训练,而无需重置状态。没有对错之分,这取决于您使用的数据以及您希望网络的行为方式。
【讨论】:
fit 强制它。
如果您使用大型数据集并且没有可轻松序列化的数据(如高等级 numpy 数组)来写入 tfrecord,那么与 fit 和 fit 生成器相比,Train_on_batch 的性能也会有所提高。
在这种情况下,您可以将数组保存为 numpy 文件并在内存中加载它们的较小子集(traina.npy、trainb.npy 等),当整个集合不适合内存时。然后,您可以使用 tf.data.Dataset.from_tensor_slices,然后将 train_on_batch 与您的子数据集一起使用,然后加载另一个数据集并再次调用 train on batch,等等,现在您已经对整个数据集进行了训练,并且可以准确控制多少和什么你的数据集训练你的模型。然后,您可以使用从数据集中获取的简单循环和函数来定义自己的时期、批量大小等。
【讨论】:
train_on_batch 对 RL 很重要,它一次训练 1 步,fit 会非常慢
@nbro 的回答确实有帮助,只是为了添加更多场景,假设您正在训练一些 seq to seq 模型或具有一个或多个编码器的大型网络。我们可以使用 train_on_batch 创建自定义训练循环,并使用我们的部分数据直接在编码器上进行验证,而无需使用回调。为复杂的验证过程编写回调可能很困难。有几种情况我们希望批量训练。
问候, 卡西克
【讨论】: