【问题标题】:What does train_on_batch() do in keras model?train_on_batch() 在 keras 模型中做了什么?
【发布时间】:2018-07-11 00:53:53
【问题描述】:

我看到了一个代码示例(这里太大,无法粘贴),其中作者使用model.train_on_batch(in, out) 而不是model.fit(in, out)。 Keras 的官方文档说:

对一批样本进行单一梯度更新。

但我不明白。它是否与fit() 相同,但不是执行许多前馈和反向传播步骤,而是执行一次?还是我错了?

【问题讨论】:

标签: python tensorflow machine-learning keras artificial-intelligence


【解决方案1】:

是的,train_on_batch 只使用一个批次训练一次。

fit 为许多时期训练许多批次。 (每批都会导致权重更新)。

使用train_on_batch的想法可能是在每批之间自己做更多的事情。

【讨论】:

  • 能否请您添加我们使用 train_on_batch 的原因?当整个数据无法放入内存时使用它,因此我们将其分成块(批次)并使用 train_on_batch() 而不是 fit() ?
  • 如果您的数据不适合内存,您将使用:1 - 如果问题出在模型中,请使用较小的批次 // 2 - 如果问题是总 numpy 数组,请使用生成器和fit_generator。如果您想在批次之间手动执行操作,请仅使用 train_on_batch
【解决方案2】:

当我们想了解并在每次批量训练后进行一些自定义更改时使用它。

更精确的用例是 GAN。 您必须更新鉴别器,但在更新 GAN 网络期间,您必须使鉴别器无法训练。所以你首先训练判别器,然后训练 gan,使判别器无法训练。 看到这个以获得更多理解: https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3

【讨论】:

    【解决方案3】:

    模型的方法fit通过你给它的数据训练模型一次,但是由于内存(尤其是GPU内存)的限制,我们不能一次训练大量样本,所以我们需要将这些数据分成小块,称为小批量(或只是批量)。 keras模型的methode fit会为你做这个数据划分,把你给它的所有数据都传过去。

    但是,有时我们需要更复杂的训练过程,例如在每个 epoch 随机选择新样本放入批处理缓冲区(例如 GAN 训练和 Siamese CNNs 训练......),在这种情况下我们不使用花哨的简单拟合方法,但我们使用 train_on_batch 方法。为了使用这种方法,我们在每次迭代中生成一批输入和一批输出(标签)并将其传递给该方法,它将一次在批次中的整个样本上训练模型,并为我们提供损失和其他指标根据批次样本计算。

    【讨论】:

      猜你喜欢
      • 2018-08-12
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-12-09
      • 2011-08-01
      • 1970-01-01
      • 2019-01-28
      • 2020-01-29
      相关资源
      最近更新 更多