【Question Title】: Tensorflow 2 ModelCheckpoint callback with multiclass recall custom metric
【Posted】: 2020-09-21 23:27:42
【Question】:

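I am building a CNN classifier for a multiclass classification task (num_classes=7). Because of class imbalance and the subject domain, my target metric for this task is the macro-averaged recall across classes.

While the model trains, I want to checkpoint it at the end of each epoch, saving the model whenever the validation macro recall is higher than the best value seen in any previous epoch. I believe this takes two steps:

  1. Create a custom metric that computes the recall averaged across classes on the validation data at the end of each epoch.
  2. Create a ModelCheckpoint callback that tracks the custom metric and saves the model whenever it exceeds the previous maximum.

Does anyone have an example of this or something similar? I am more interested in the implementation of the custom metric for macro-averaged multiclass recall, since I believe the callback is straightforward once the metric is defined in model.compile().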

【Question Discussion】:

    Tags: python keras tensorflow2.0


    【Answer 1】:

    I implemented the custom metric based on this post, with a few adjustments such as computing a running mean. Here is the code for the custom metric:

    import tensorflow.keras.backend as K
    from tensorflow.keras.metrics import Metric

    class MacroAverageRecall(Metric):
        """Custom metric for calculating multiclass recall during training."""

        def __init__(self,
                     num_classes,
                     batch_size,
                     name='multiclass_recall',
                     **kwargs):
            super().__init__(name=name, **kwargs)
            self.batch_size = batch_size  # stored but not used in the computation
            self.num_classes = num_classes
            # Batch counter kept as a metric variable rather than a Python int,
            # so that it is updated when update_state runs inside a compiled graph.
            self.num_batches = self.add_weight(name="num_batches", initializer="zeros")
            self.average_recall = self.add_weight(name="recall", initializer="zeros")

        def update_state(self, y_true, y_pred, sample_weight=None):
            recall = 0.
            # Class indices from per-class scores and one-hot labels
            pred = K.argmax(y_pred, axis=-1)
            true = K.argmax(y_true, axis=-1)

            for i in range(self.num_classes):
                # 1.0 where the prediction / the label equals class i, else 0.0
                predicted_instances = K.cast(K.equal(pred, i), 'float32')
                true_instances = K.cast(K.equal(true, i), 'float32')
                # True positives: label and prediction are both class i
                true_positives = K.sum(true_instances * predicted_instances)
                # Divide by the number of samples whose label is class i
                all_true_positives = K.sum(true_instances) + K.epsilon()
                recall += true_positives / all_true_positives

            self.num_batches.assign_add(1.)
            avg_recall = recall / self.num_classes
            # Running mean of the per-batch macro recall
            recall_update = (avg_recall - self.average_recall) / self.num_batches
            self.average_recall.assign_add(recall_update)

        def result(self):
            return self.average_recall

        def reset_states(self):
            # The state of the metric will be reset at the start of each epoch.
            self.average_recall.assign(0.)
            self.num_batches.assign(0.)

        # Newer TensorFlow releases call reset_state() instead of reset_states().
        reset_state = reset_states
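As a rough illustration of what the running mean in this metric reports, the sketch below recomputes it in plain Python on two made-up validation batches and compares it with the macro recall computed over all samples at once. The helper `macro_recall` and the batch contents are hypothetical, and classes absent from a batch are counted as recall 0, which approximates what the epsilon term in the metric does. The two numbers can differ, because averaging per-batch values weights every batch equally regardless of which classes it contains.

```python
def macro_recall(y_true, y_pred, num_classes):
    """Mean of per-class recall; a class with no true samples contributes 0."""
    total = 0.0
    for c in range(num_classes):
        support = sum(1 for t in y_true if t == c)
        hits = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        total += hits / support if support else 0.0
    return total / num_classes


# Two hypothetical validation batches of (labels, predictions), 3 classes.
batches = [
    ([0, 0, 1, 2], [0, 1, 1, 2]),
    ([0, 1, 1, 1], [0, 1, 0, 0]),
]

# Incremental running mean, using the same update rule as the metric.
running, n = 0.0, 0
for y_true, y_pred in batches:
    n += 1
    running += (macro_recall(y_true, y_pred, 3) - running) / n

# Macro recall over all validation samples at once.
all_true = [t for y_true, _ in batches for t in y_true]
all_pred = [p for _, y_pred in batches for p in y_pred]
epoch_level = macro_recall(all_true, all_pred, 3)

print(f"mean of per-batch macro recall: {running:.4f}")
print(f"macro recall over the epoch:    {epoch_level:.4f}")
```

With larger batches that contain every class, the gap between the two values tends to shrink.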
    

    And the checkpoint used when training the model:

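    # Excerpt from inside a training class; assumes `import os` and
    # `from tensorflow.keras import callbacks` at module level.
    callbacks.ModelCheckpoint(
        filepath=os.path.join(
            self._metadata['checkpoint_directory'],
            f'checkpoint-{self._metadata["create_time"]}.h5'),
        # Save only on improvement when self._val is truthy
        save_best_only=bool(self._val),
        # 'val_' prefix + the metric's name ('multiclass_recall')
        monitor='val_multiclass_recall',
        mode='max',
        verbose=1)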
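For an end-to-end picture, here is a minimal sketch of how a metric named `multiclass_recall` can be wired into `model.compile()` and monitored by `ModelCheckpoint`. It is not the answer's code: `EpochMacroRecall` is a hypothetical variant that accumulates per-class counts over the whole epoch instead of averaging per-batch values, and the tiny model, the input size of 16, and the file name `best_model.h5` are placeholders. It assumes one-hot labels and per-class scores, as in the answer.

```python
import tensorflow as tf


class EpochMacroRecall(tf.keras.metrics.Metric):
    """Hypothetical variant: macro recall from counts accumulated per epoch."""

    def __init__(self, num_classes, name="multiclass_recall", **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.true_positives = self.add_weight(
            name="tp", shape=(num_classes,), initializer="zeros")
        self.support = self.add_weight(
            name="support", shape=(num_classes,), initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # One-hot encode the label and the predicted class of each sample.
        true = tf.one_hot(tf.argmax(y_true, axis=-1), self.num_classes)
        pred = tf.one_hot(tf.argmax(y_pred, axis=-1), self.num_classes)
        # Per-class true positives and per-class number of true samples.
        self.true_positives.assign_add(tf.reduce_sum(true * pred, axis=0))
        self.support.assign_add(tf.reduce_sum(true, axis=0))

    def result(self):
        # Classes never seen in the epoch get recall 0 instead of NaN.
        per_class = tf.math.divide_no_nan(self.true_positives, self.support)
        return tf.reduce_mean(per_class)

    def reset_state(self):
        self.true_positives.assign(tf.zeros((self.num_classes,)))
        self.support.assign(tf.zeros((self.num_classes,)))

    # Older TF 2.x releases call reset_states() instead of reset_state().
    reset_states = reset_state


# Placeholder model: 16 input features, 7 output classes.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(7, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=[EpochMacroRecall(num_classes=7)])

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_model.h5",          # placeholder path
    monitor="val_multiclass_recall",   # "val_" + the metric's name
    mode="max",
    save_best_only=True,
    verbose=1)

# With real data (x_train, y_train, x_val, y_val are not defined here):
# model.fit(x_train, y_train, validation_data=(x_val, y_val),
#           epochs=10, callbacks=[checkpoint])
```

The monitored key only exists when validation data is passed to `fit`, since Keras prefixes validation metrics with `val_`.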
    
