【问题标题】:Save model every 10 epochs tensorflow.keras v2每 10 个 epoch 保存一次模型 tensorflow.keras v2
【发布时间】:2020-03-22 22:36:44
【问题描述】:

我正在使用在 tensorflow v2 中定义为子模块的 keras。我正在使用fit_generator() 方法训练我的模型。我想每 10 个 epoch 保存一次我的模型。我怎样才能做到这一点?

在 Keras(不作为 tf 的子模块)中,我可以给ModelCheckpoint(model_savepath,period=10)。但在 tf v2 中,他们已将其更改为 ModelCheckpoint(model_savepath, save_freq),其中 save_freq 可以是 'epoch',在这种情况下,模型会在每个时期保存。如果save_freq 是整数,则在处理完这么多样本后保存模型。但我希望它在 10 个时代之后。我怎样才能做到这一点?

【问题讨论】:

    标签: python keras deep-learning tensorflow2.0 tf.keras


    【解决方案1】:

    使用tf.keras.callbacks.ModelCheckpoint 使用save_freq='epoch' 并传递一个额外的参数period=10

    虽然official docs 中没有记录这一点,但这就是这样做的方式(注意记录了你可以通过period,只是没有解释它的作用)。

    【讨论】:

    • 我收到以下警告:WARNING:tensorflow:'period' argument is deprecated. Please use 'save_freq' to specify the frequency in number of samples seen. 所以,我想,此功能即将推出。在那种情况下,我怎样才能做到这一点?
    • 我相信唯一的选择是计算每个时期的示例数,并将该整数传递给 save_freq 乘以您想要的时期数作为保存之间的间隔
    • @bluesummers "examples per epoch" 这应该是我的批量大小,对吧?
    • Examples per epoch 是您希望在检查点之间通过网络的 samples 数量 - 这意味着如果您有 100 个样本(samples != batch,batch 是一批样本),你放 400 个,它将每 4 个 epoch 保存一次
    • 我遇到了与@NagabhushanSN 相同的问题。我计算了每个时期的样本数来计算我想保存模型的样本数,但它似乎不起作用。批量大小 = 64,对于测试用例,我每个 epoch 使用 10 个步骤。如果我想每 3 个 epoch 保存一次模型,样本数是 64*10*3=1920。我将它用于 sav_freq 但输出显示模型保存在 epoch 1、epoch 2、epoch 9、epoch 11、epoch 14 并且仍在运行。无法理解。 period 选项似乎工作正常,但有消息表明它将被弃用。
    【解决方案2】:

    显式计算每个时期的批次数对我有用。

    BATCH_SIZE = 20
    STEPS_PER_EPOCH = train_labels.size / BATCH_SIZE
    SAVE_PERIOD = 10
    
    # Create a callback that saves the model's weights every 10 epochs
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path, 
        verbose=1, 
        save_weights_only=True,
        save_freq= int(SAVE_PERIOD * STEPS_PER_EPOCH))
    
    # Train the model with the new callback
    model.fit(train_images, 
              train_labels,
              batch_size=BATCH_SIZE,
              steps_per_epoch=STEPS_PER_EPOCH,
              epochs=50, 
              callbacks=[cp_callback],
              validation_data=(test_images,test_labels),
              verbose=0)
    

    【讨论】:

      猜你喜欢
      • 2018-11-07
      • 2018-12-13
      • 1970-01-01
      • 2018-06-16
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多