【问题标题】:Keras training with shuffled tf.data: if training is interrupted, how to continue training at last data iteration/order of last saved checkpoint使用打乱的 tf.data 进行 Keras 训练:如果训练中断,如何在最后一次数据迭代/最后保存的检查点的顺序处继续训练
【发布时间】:2020-10-13 03:59:18
【问题描述】:

我正在使用 keras model.fit 进行训练,数据来自 tf.records,加载到 tf.data 对象中,该对象使用 .shuffle 对数据进行洗牌。我还使用callbacks.ModelCheckpoint 以每x 步数/批次数保存模型。

有时我的云实例会在一个纪元完成之前断开连接或崩溃,但y 步骤的模型已保存到我的驱动器中。

我想在训练另一个 epoch 之前完成对那个 epoch 中数据的训练(我有很长的 epoch),所以每个数据示例在每个 epoch 训练一次。

有没有办法获取数据的原始顺序,以及模型最后保存在数据中的位置?

到目前为止我发现了什么

看起来您可以通过设置种子在 .shuffle 中设置特定顺序。但是,洗牌只发生在缓冲区中,所以我不能 100% 确定设置种子是否会完美地重现订单。另外,我不确定这将如何与reshuffle_each_iteration 一起使用。每个时期之后是否使用不同的种子?如果是这样,我想一个解决方法是一次只训练 1 个 epoch,每个 epoch 都有一个指定的种子。

即使我确实获得了训练顺序的副本,我也不确定如何在顺序中找到模型最后保存的位置,然后从该点开始训练。我必须得到的一个想法是手动遍历数据集,直到我到达它。虽然我不确定model.fit() 是否会继续执行此命令,或者重新开始。 F

为了从上次保存模型的位置获取步骤/批次编号,我可能可以将其记录在某个地方。

这些解决方案似乎是粗略的解决方法,我想知道 Keras 中是否有一些我可能忽略的功能可以帮助解决这个问题。

【问题讨论】:

  • 我认为没有为此设计的功能,因为在大多数情况下这不是问题。从它在一个纪元中间停止的位置继续训练而不是重新开始一个新纪元似乎仅用于与其他算法/超参数集进行比较。而且在大多数情况下,epoch 足够短,只能在 epoch 结束时而不是在它中间来减轻重量。
  • 我不确定您是否可以恢复数据集的迭代,但是您可以使用 EarlyStopping restore_best_weights=True 这样即使您重新启动,您也将使用最好的检查点而不是最后一个模型仅限检查点

标签: tensorflow keras tensorflow2.0 tensorflow-datasets tf.keras


【解决方案1】:

似乎没有 keras 构建方法可以做到这一点,但如果我错了,请纠正我。

我的方法

Dataset.shuffle 在内部使用初始种子值生成种子,用于在迭代期间重新洗牌reshuffle_each_iteration=True。因此,为特定时期重新创建相同的顺序并在该特定批次继续训练时期,我们必须重新创建具有相同种子的数据集并将数据集迭代器移动到相同时期和相同批次。

调试

为了调试并确保以相同的顺序生成 epoch 和 batch,我们需要一种方法来打印每个 epoch-batch 中数据点的拾取方式。这很棘手,因此出于调试目的,我将使用回归问题并将基本事实作为序列号。然后我可以有一个自定义损失,我可以在其中打印基本事实并使用户的顺序正确。

模型和数据

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import keras.backend as K


# Data
x_train = np.random.randn(15, 10).astype("float32")
y_train = np.arange(15).astype("float32")

# Custom MSE looss just to track the order in which data is picked up
def my_mse(y_true, y_pred):
    tf.print(tf.keras.backend.flatten(y_true))
    loss = K.square(y_pred - y_true)
    loss = K.sum(loss, axis=1)
    return loss

# Model
def get_model():
    inputs = keras.Input(shape=(10))    
    outputs = layers.Dense(1, activation="linear")(inputs)
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    model.compile(
        optimizer="rmsprop",
        loss=my_mse,
    )
    return model

数据集

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8)

epochs = 2

print ("Runs 1")
for e in range(epochs):
  for i, (x, y) in enumerate(train_dataset):
    print (e, i, y)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8)
print ("Runs 2")
for e in range(epochs):
  for i, (x, y) in enumerate(train_dataset):
    print (e, i, y)

输出:

Runs 1
0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32)
1 tf.Tensor([ 6. 11. 10. 14.  9. 12. 13.], shape=(7,), dtype=float32)
2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32)
3 tf.Tensor([13. 10.  0. 14.  6. 11. 12.], shape=(7,), dtype=float32)
4 tf.Tensor([ 0.  1.  5.  6.  9.  3.  7. 14.], shape=(8,), dtype=float32)
5 tf.Tensor([13.  8.  4. 10.  2. 12. 11.], shape=(7,), dtype=float32)
Runs 2
0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32)
1 tf.Tensor([ 6. 11. 10. 14.  9. 12. 13.], shape=(7,), dtype=float32)
2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32)
3 tf.Tensor([13. 10.  0. 14.  6. 11. 12.], shape=(7,), dtype=float32)
4 tf.Tensor([ 0.  1.  5.  6.  9.  3.  7. 14.], shape=(8,), dtype=float32)
5 tf.Tensor([13.  8.  4. 10.  2. 12. 11.], shape=(7,), dtype=float32)

是的,使用种子复制订单。

现在让我们编写一个方法将数据集转发到某个时期和批次组合

def forward(dataset, n=None):
  if not n:
    return dataset

  i = 0  
  while True:
    for _ in dataset:        
        i += 1
        if i == n:
          return dataset

测试用例:

让我们正常运行并观察顺序

从头开始的数据

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None)

model = get_model()
model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False)

输出:

[7 3 6 10]
[11 0 1 2]
[8 14 9 13]
[12 5 4]
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

数据集第 n 个状态的数据

让我们的数据集进行第 4 次迭代并运行训练

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4)

model = get_model()
model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False)

输出:

[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

很好,现在我们知道如何正确转发数据集了。现在让我们编写回调来跟踪当前的迭代次数:

跟踪迭代的自定义回调(epoch-batch 组合)

现在我们需要确定模型被检查指向的时期和批次组合。如果我们有这些信息,我们可以加载最后一个检查点模型并将我们的数据集转发到它的批次和时期组合并继续训练。我们将使用回调来做到这一点

class MyCustomCallback(tf.keras.callbacks.ModelCheckpoint, keras.callbacks.Callback):
    def __init__(self, the_id=0, **args):
      self.the_id = the_id
      self.epoch = 0
      super().__init__(**args)

    def _save_model(self, epoch, logs):
      logs['the_id'] = self.the_id
      super()._save_model(epoch, logs)

    def on_batch_end(self, batch, logs={}):
      self.the_id += 1
      super().on_batch_end(batch, logs)

checkpoint_filepath = 'checkpoint-{the_id}'
model_checkpoint_callback = MyCustomCallback(
    filepath=checkpoint_filepath,
    save_freq=2,
    save_best_only=False)

model = get_model()

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None)

model.fit(train_dataset, epochs=5, verbose=0, callbacks=[model_checkpoint_callback], workers=4, shuffle=False)

输出:

[7 3 6 10]
[11 0 1 2]
[8 14 9 13]
[12 5 4]
[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

我们每两批检查一次。所以让我们假设它崩溃并且最后一个检查点是checkpoint-4。我们可以加载这个模型并将我们的数据集转发到 4 并继续训练。

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4)

model = get_model()
model.fit(train_dataset, epochs=2, verbose=0, workers=4, shuffle=False)

输出:

[5 8 6 3]
[1 12 10 9]
[2 11 0 4]
[14 13 7]
[2 3 0 10]
[4 1 13 6]
[8 7 14 11]
[12 5 9]

【讨论】:

    【解决方案2】:

    我想你想恢复 shuffle order 以避免在这个 epoch 内重复一些样本。

    根据shuffle description,在未完成的时期,您的模型只能访问数据集中的第一个 current_step_number + shuffle_buffer_size 样本。

    因此,当您恢复训练时,如果您知道处理了多少步,则可以跳过此步骤 + 跳过 shuffle_buffer_size 步骤,然后您将继续在以下样本上进行训练,这在当前 epoch 内尚未观察到。

    请注意,在此时期根本不会观察到来自数据集第一部分的一些随机 shuffle_buffer_size 样本。正如你所说,你的时代很长,所以,可能你有很多数据,所以丢失 shuffle_buffer_size 样本对你来说应该不是问题。

    所以在保存检查点的过程中也要保存步数,然后在加载检查点后创建带有跳过步骤的数据集副本(使用 dataset.skip),然后将 model.fit 与这个较小的数据集一起使用一个时期(以完成当前时期),然后继续以平常的方式进行训练。

    【讨论】:

      猜你喜欢
      • 2021-08-06
      • 1970-01-01
      • 2020-10-28
      • 2020-12-13
      • 2020-10-29
      • 2022-10-17
      • 2018-10-07
      • 2018-01-05
      • 2017-07-28
      相关资源
      最近更新 更多