【Question title】: Keras ModelCheckpoint not saving but EarlyStopping is working fine with the same monitor argument
【Posted】: 2021-01-25 19:26:29
【Question】:

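I've built a model and I'm validating it with a custom callback. The problem is that my custom validation callback stores the validation accuracy in the logs dictionary, but Keras ModelCheckpoint somehow doesn't see it. EarlyStopping, monitoring the same key, works fine.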

Here is the code for the validation class:

import numpy as np
from sklearn.metrics import accuracy_score
from tensorflow import keras
from tensorflow.keras import backend as K

class ValidateModel(keras.callbacks.Callback):
    
    def __init__(self, validation_data, loss_fnc):
        super().__init__()
        self.validation_data = validation_data
        self.loss_fnc = loss_fnc
    
    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the logs dict in; avoid a mutable default argument.
        if logs is None:
            logs = {}
        
        th = 0.5
        
        features = self.validation_data[0]
        y_true = self.validation_data[1].reshape((-1,1))     
        
        y_pred = np.asarray(self.model.predict(features)).reshape((-1,1))
        
        # Compute the validation loss with the supplied loss function.
        y_true_tensor = K.constant(y_true)
        y_pred_tensor = K.constant(y_pred)
        
        val_loss = K.eval(self.loss_fnc(y_true_tensor, y_pred_tensor))
        
        # Binarize the predictions at threshold th:
        # values below th become 0, values at or above th become 1.
        y_pred_rounded = y_pred / th
        y_pred_rounded = np.clip(np.floor(y_pred_rounded).astype(int),0,1)
        
        val_acc = accuracy_score(y_true, y_pred_rounded)
        
        # Write the metrics into logs so that other callbacks can monitor them.
        logs['val_loss'] = val_loss
        logs['val_acc'] = val_acc
        
        print(f'\nval_loss: {val_loss} - val_acc: {val_acc}')
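For reference, the floor/clip rounding in on_epoch_end marks predictions at or above th as 1. A minimal NumPy sketch of the thresholding and accuracy steps, with hypothetical function names and a plain mean standing in for sklearn's accuracy_score so that it runs with NumPy alone:

```python
import numpy as np


def round_at_threshold(y_pred, th=0.5):
    """Binarize predictions as ValidateModel does: floor(y / th), clipped to {0, 1}.

    For th > 0 and non-negative predictions this is equivalent to y >= th
    (up to floating-point rounding right at the boundary).
    """
    y_pred = np.asarray(y_pred, dtype=float).reshape(-1)
    return np.clip(np.floor(y_pred / th).astype(int), 0, 1)


def thresholded_accuracy(y_true, y_pred, th=0.5):
    """Fraction of binarized predictions that equal the 0/1 labels."""
    y_true = np.asarray(y_true).reshape(-1)
    return float(np.mean(round_at_threshold(y_pred, th) == y_true))


y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0.2, 0.7, 0.4, 0.5])
print(round_at_threshold(y_pred))            # [0 1 0 1]
print(thresholded_accuracy(y_true, y_pred))  # 0.5
```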

Here is the function I use to train the model:

def train_generator_model(model):
    steps = int(train_df.shape[0] / TRAIN_BATCH_SIZE)

    cb_validation = ValidateModel([validation_X, validation_y], iou)
    cb_early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_acc', 
                                                     patience=3, 
                                                     mode='max', 
                                                     verbose = 1)
    cb_model_checkpoint = tf.keras.callbacks.ModelCheckpoint('/kaggle/working/best_generator_model.hdf5',
                                                             monitor='val_acc',
                                                             save_best_only=True,
                                                             mode='max',
                                                             verbose=1)

    history = model.fit(
        x = train_datagen, 
        epochs = 2, ##Setting to 2 to test.
        callbacks = [cb_validation, cb_model_checkpoint, cb_early_stop], 
        verbose = 1,
        steps_per_epoch = steps)
    
    #model = tf.keras.models.load_model('/kaggle/working/best_generator_model.hdf5', custom_objects = {'iou':iou})
    #model.load_weights('/kaggle/working/best_generator_model.hdf5')
    
    return history

If I set the ModelCheckpoint argument save_best_only to False, the model is saved without any problem. When training finishes and I inspect history.history, I can see that val_loss and val_acc are being recorded:

{'loss': [0.13096405565738678, 0.11926634609699249], 'binary_accuracy': [0.9692355990409851, 0.9716895818710327], 'val_loss': [0.23041087, 0.18325138], 'val_acc': [0.9453247578938803, 0.956172612508138]}
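The fact that saving works with save_best_only=False points at the monitored value: with save_best_only=True the checkpoint callback has to read the monitor key from the logs it receives, and if that key is missing it logs a warning and skips saving instead of raising an error. A simplified sketch of that decision, assuming only mode='max' and mode='min' (checkpoint_decision is a hypothetical name; this is not TensorFlow's implementation):

```python
import logging
import math
import operator


def checkpoint_decision(logs, monitor, best, mode="max", save_best_only=True):
    """Return (should_save, new_best) for one epoch.

    Simplified model of the save_best_only check, not TensorFlow code.
    Only mode="max" and mode="min" are modeled.
    """
    if not save_best_only:
        return True, best  # saves every epoch; the monitored value is never read
    current = logs.get(monitor)
    if current is None:
        # Missing key: warn and skip, nothing is raised.
        logging.warning("Can save best model only with %s available, skipping.", monitor)
        return False, best
    improved = operator.gt if mode == "max" else operator.lt
    if improved(current, best):
        return True, current
    return False, best


best = -math.inf  # starting point for mode="max"
print(checkpoint_decision({"val_acc": 0.945}, "val_acc", best))  # (True, 0.945)
print(checkpoint_decision({"loss": 0.131}, "val_acc", best))     # (False, -inf)
```

If the key is absent from the dict ModelCheckpoint receives, every epoch takes the "skip" branch, which matches a run where nothing is saved and no exception appears.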

I'm using TensorFlow 2.3.1 and importing keras from tensorflow.

Any help is appreciated. Thanks!

【Discussion】:

    Tags: python tensorflow keras callback


    【Solution 1】:

    I checked the TensorFlow source and found an incompatibility between TensorFlow and Keras. The tensorflow.keras.callbacks file contains the following code:

    from keras.utils import tf_utils
    

    The problem is that keras.utils has no tf_utils module (at least not in Keras 2.4.3, which is what I'm using). Strangely, no exception is raised.
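Whether that import path resolves in a given environment can be checked directly. A minimal standard-library sketch (module_available is a hypothetical helper, not part of Keras or TensorFlow) that reports whether each module can be found, without raising if a parent package is missing:

```python
import importlib.util


def module_available(dotted_name):
    """Return True if `dotted_name` can be located by the import system.

    find_spec imports the parent packages first, so a missing (or broken)
    parent raises instead of returning None; treat that as "not available".
    """
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except Exception:  # ModuleNotFoundError, or errors from a broken install
        return False


if __name__ == "__main__":
    for name in ("keras.utils.tf_utils", "tensorflow.python.keras.utils.tf_utils"):
        print(name, module_available(name))
```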

    Fix #1: Add the following code to your program:

    class ModelCheckpoint_tweaked(tf.keras.callbacks.ModelCheckpoint):
        def __init__(self,
                       filepath,
                       monitor='val_loss',
                       verbose=0,
                       save_best_only=False,
                       save_weights_only=False,
                       mode='auto',
                       save_freq='epoch',
                       options=None,
                       **kwargs):
            
            #Change tf_utils source package.
            from tensorflow.python.keras.utils import tf_utils
            
            super(ModelCheckpoint_tweaked, self).__init__(filepath,
                       monitor,
                       verbose,
                       save_best_only,
                       save_weights_only,
                       mode,
                       save_freq,
                       options,
                       **kwargs)
    

    Then use this new class as the ModelCheckpoint callback:

    cb_model_checkpoint = ModelCheckpoint_tweaked(file_name,
                                                  monitor='val_acc',
                                                  save_best_only=True,
                                                  mode='max',
                                                  verbose=1)
    

    Fix #2:

    Upgrade TensorFlow to version 2.4.0. If you use a custom callback to compute the monitored metric, add the following line to the custom callback's __init__() method:

    self._supports_tf_logs = True
    

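    Without this line, the values your callback writes into logs are not preserved for every other callback: here, ModelCheckpoint never sees val_acc, while EarlyStopping does.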
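Why the flag matters can be illustrated with a toy, pure-Python model of how on_epoch_end logs appear to be handed out in TF 2.3/2.4: callbacks with _supports_tf_logs = True receive the original logs dict, while all other callbacks share a single converted copy. The classes below are hypothetical stand-ins, not Keras code, and the sketch assumes that ModelCheckpoint sets the flag while EarlyStopping does not, which is worth verifying against the installed TensorFlow source.

```python
class ToyCallback:
    """Hypothetical stand-in for a Keras callback (illustration only)."""

    def __init__(self, supports_tf_logs):
        self._supports_tf_logs = supports_tf_logs

    def on_epoch_end(self, epoch, logs):
        pass


class ToyValidator(ToyCallback):
    def on_epoch_end(self, epoch, logs):
        logs["val_acc"] = 0.95  # writes into whatever dict it was given


class ToyRecorder(ToyCallback):
    def __init__(self, supports_tf_logs):
        super().__init__(supports_tf_logs)
        self.seen = None

    def on_epoch_end(self, epoch, logs):
        self.seen = logs.get("val_acc")  # what a monitoring callback would read


def dispatch_epoch_end(callbacks, epoch, logs):
    """Simplified model of the dispatch, assumed from TF 2.3/2.4 CallbackList."""
    converted = None
    for cb in callbacks:
        if cb._supports_tf_logs:
            cb.on_epoch_end(epoch, logs)  # the original dict
        else:
            if converted is None:
                converted = dict(logs)  # stands in for the one-time NumPy conversion
            cb.on_epoch_end(epoch, converted)


# Validator without the flag: its write lands only in the shared copy.
checkpoint, early_stop = ToyRecorder(True), ToyRecorder(False)
dispatch_epoch_end([ToyValidator(False), checkpoint, early_stop], 0, {"loss": 0.13})
print(checkpoint.seen, early_stop.seen)  # None 0.95

# Validator with the flag: it writes into the original dict, which both see.
checkpoint, early_stop = ToyRecorder(True), ToyRecorder(False)
dispatch_epoch_end([ToyValidator(True), checkpoint, early_stop], 0, {"loss": 0.13})
print(checkpoint.seen, early_stop.seen)  # 0.95 0.95
```

In the second run the copy is made only after the validator has written val_acc, which is why the flag on the custom callback is enough for both consumers to see the metric in this model.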
