【问题标题】:Which callback to apply with SWA: AverageModelCheckpoint or ModelCheckpoint?SWA 应用哪个回调:AverageModelCheckpoint 或 ModelCheckpoint?
【发布时间】:2020-12-17 20:47:30
【问题描述】:

我试图了解使用 SWA (Tensorflow addons implementation) 有和没有回调的区别。所以,我有两个实现,它们都工作得很好。但是,我不确定两者之间的区别。第一个是没有 swa 回调:

optimizer = tf.keras.optimizers.Adam(learning_rate=.01, beta_1=0.9, beta_2=0.999)
opt = tfa.optimizers.SWA(optimizer, start_averaging=start_epoch,
                   average_period=1, lr=0.005)
model.compile(loss=lossFunction, optimizer=opt, metrics=['accuracy'])
print(model.summary())
# simple early stopping
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=20)
mc = tf.keras.callbacks.ModelCheckpoint(checkpointFilePath, monitor='val_accuracy',
                     mode='max', verbose=1, save_best_only=True)
history = model.fit(trainX, trainy, validation_data=(valX, valy),
                    batch_size=batch_size,
                    epochs=epochs, verbose=0, callbacks=[es, mc])

第二个是一样的,只是多了一个回调:

swa = tfa.callbacks.AverageModelCheckpoint(filepath=checkpointFilePath,
                                                    update_weights=True)

而当添加swa回调时,最后一行变化如下:

history = model.fit(trainX, trainy, validation_data=(valX, valy),
                    batch_size=batch_size,
                    epochs=epochs, verbose=0, callbacks=[es, mc, swa])

我的问题是这两种实现有什么区别?如果有的话,哪个更好?

【问题讨论】:

    标签: python tensorflow optimization callback tensorflow2.0


    【解决方案1】:

    ModelCheckpoint 回调让您可以在某个时间间隔保存模型或权重。如果您想稍后加载权重以从保存的状态继续训练,这很有用。

    但是,使用ModelCheckpoint,您无法在某个时间间隔保存移动平均权重。也就是说,模型平均优化器需要自定义回调来保存权重。使用AverageModelCheckpoint,您可以保存权重并可选择将此平均权重分配给模型。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2016-08-20
      • 1970-01-01
      • 1970-01-01
      • 2020-05-27
      • 2020-10-08
      • 1970-01-01
      • 2020-09-21
      • 1970-01-01
      相关资源
      最近更新 更多