【问题标题】:How to prevent Keras from computing metrics during training如何防止 Keras 在训练期间计算指标
【发布时间】:2022-07-19 22:38:16
【问题描述】:

我正在使用 Tensorflow/Keras 2.4.1,并且我有一个(无监督的)自定义指标,它将我的几个模型输入作为参数,例如:

model = build_model() # returns a tf.keras.Model object
my_metric = custom_metric(model.output, model.input[0], model.input[1])
model.add_metric(my_metric)
[...]
model.fit([...]) # training with fit

但是,custom_metric 恰好非常昂贵,所以我希望仅在验证期间计算它。我找到了这个answer,但我几乎不明白如何使解决方案适应使用多个模型输入作为参数的指标,因为update_state 方法似乎不灵活。

在我的上下文中,除了编写我自己的训练循环之外,有没有办法避免在训练期间计算我的指标? 另外,我很惊讶我们不能在原生的 Tensorflow 中指定某些指标只能在验证时计算,这有什么原因吗?

此外,由于模型经过训练以优化损失,并且训练数据集不应该用于评估模型,我什至不明白为什么默认情况下 Tensorflow 在训练期间计算指标。

【问题讨论】:

    标签: python tensorflow machine-learning keras deep-learning


    【解决方案1】:

    我认为仅在验证时计算指标的最简单解决方案是使用自定义回调。

    这里我们定义了我们的虚拟回调:

    class MyCustomMetricCallback(tf.keras.callbacks.Callback):
    
        def __init__(self, train=None, validation=None):
            super(MyCustomMetricCallback, self).__init__()
            self.train = train
            self.validation = validation
    
        def on_epoch_end(self, epoch, logs={}):
    
            mse = tf.keras.losses.mean_squared_error
    
            if self.train:
                logs['my_metric_train'] = float('inf')
                X_train, y_train = self.train[0], self.train[1]
                y_pred = self.model.predict(X_train)
                score = mse(y_train, y_pred)
                logs['my_metric_train'] = np.round(score, 5)
    
            if self.validation:
                logs['my_metric_val'] = float('inf')
                X_valid, y_valid = self.validation[0], self.validation[1]
                y_pred = self.model.predict(X_valid)
                val_score = mse(y_pred, y_valid)
                logs['my_metric_val'] = np.round(val_score, 5)
    

    鉴于这个虚拟模型:

    def build_model():
    
      inp1 = Input((5,))
      inp2 = Input((5,))
      out = Concatenate()([inp1, inp2])
      out = Dense(1)(out)
    
      model = Model([inp1, inp2], out)
      model.compile(loss='mse', optimizer='adam')
    
      return model
    

    还有这个数据:

    X_train1 = np.random.uniform(0,1, (100,5))
    X_train2 = np.random.uniform(0,1, (100,5))
    y_train = np.random.uniform(0,1, (100,1))
    
    X_val1 = np.random.uniform(0,1, (100,5))
    X_val2 = np.random.uniform(0,1, (100,5))
    y_val = np.random.uniform(0,1, (100,1))
    

    您可以使用自定义回调来计算训练和验证的指标:

    model = build_model()
    
    model.fit([X_train1, X_train2], y_train, epochs=10, 
              callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train), validation=([X_val1, X_val2],y_val))])
    

    仅在验证时:

    model = build_model()
    
    model.fit([X_train1, X_train2], y_train, epochs=10, 
              callbacks=[MyCustomMetricCallback(validation=([X_val1, X_val2],y_val))])
    

    只在火车上:

    model = build_model()
    
    model.fit([X_train1, X_train2], y_train, epochs=10, 
              callbacks=[MyCustomMetricCallback(train=([X_train1, X_train2],y_train))])
    

    请记住,回调会一次性评估数据上的指标,就像 keras 在 validation_data 上默认计算的任何指标/损失一样。

    here 是运行代码。

    【讨论】:

      【解决方案2】:

      我可以使用learning_phase,但只能在符号张量模式(图形)模式下使用:

      所以,首先我们需要禁用 Eager 模式(这必须在导入 tensorflow 后立即完成):

      import tensorflow as tf
      tf.compat.v1.disable_eager_execution()
      

      然后您可以使用符号 if (backend.switch) 创建指标:

      def metric_graph(in1, in2, out):
          actual_metric = out * (in1 + in2)
          return K.switch(K.learning_phase(), tf.zeros((1,)), actual_metric) 
      

      方法add_metric 将要求提供名称和聚合方法,您可以将其设置为"mean"

      所以,这里有一个例子:

      x1 = numpy.ones((5,3))
      x2 = numpy.ones((5,3))
      y = 3*numpy.ones((5,1))
      
      vx1 = numpy.ones((5,3))
      vx2 = numpy.ones((5,3))
      vy = 3*numpy.ones((5,1))
      
      def metric_eager(in1, in2, out):
          if (K.learning_phase()):
              return 0
          else:
              return out * (in1 + in2)
      
      def metric_graph(in1, in2, out):
          actual_metric = out * (in1 + in2)
          return K.switch(K.learning_phase(), tf.zeros((1,)), actual_metric) 
      
      ins1 = Input((3,))
      ins2 = Input((3,))
      outs = Concatenate()([ins1, ins2])
      outs = Dense(1)(outs)
      model = Model([ins1, ins2],outs)
      model.add_metric(metric_graph(ins1, ins2, outs), name='my_metric', aggregation='mean')
      model.compile(loss='mse', optimizer='adam')
      
      model.fit([x1, x2],y, validation_data=([vx1, vx2], vy), epochs=3)
      

      【讨论】:

      • 感谢您的评论,但我无法禁用急切执行,因为我的一些损失组件需要对模型的输出进行索引,而这在符号张量上是不可能的。我知道在训练时应该避免急切的执行,因为它会产生性能问题,但我没有为我的项目找到其他解决方案(这将是另一个线程。)
      【解决方案3】:

      由于指标是在keras.Modeltrain_step 函数中运行的,因此在不更改 API 的情况下过滤掉火车禁用指标需要对keras.Model 进行子类化。

      我们定义了一个简单的度量包装器:

      class TrainDisabledMetric(Metric):
      
        def __init__(self, metric: Metric):
          super().__init__(name=metric.name)
          self._metric = metric
      
        def update_state(self, *args, **kwargs):
          return self._metric.update_state(*args, **kwargs)
      
        def reset_state(self):
          return self._metric.reset_state()
      
        def result(self):
          return self._metric.result()
      

      和子类keras.Model 在训练期间过滤掉这些指标:

      class CustomModel(keras.Model):
      
        def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
      
        def compile(self, optimizer='rmsprop', loss=None, metrics=None,
                    loss_weights=None, weighted_metrics=None, run_eagerly=None,
                    steps_per_execution=None, jit_compile=None, **kwargs):
      
          from_serialized = kwargs.get('from_serialized', False)
      
          super().compile(optimizer, loss, metrics=metrics, loss_weights=loss_weights,
                          weighted_metrics=weighted_metrics, run_eagerly=run_eagerly,
                          steps_per_execution=steps_per_execution,
                          jit_compile=jit_compile, **kwargs)
      
          self.on_train_compiled_metrics = self.compiled_metrics
      
          if metrics is not None:
      
            def get_on_train_traverse_tree(structure):
              flat = tf.nest.flatten(structure)
              on_train = [not isinstance(e, TrainDisabledMetric) for e in flat]
              full_tree = tf.nest.pack_sequence_as(structure, on_train)
              return get_traverse_shallow_structure(lambda s: any(tf.nest.flatten(s)),
                                                    full_tree)
      
            on_train_sub_tree = get_on_train_traverse_tree(metrics)
            flat_on_train = flatten_up_to(on_train_sub_tree, metrics)
      
            def clean_tree(tree):
              if isinstance(tree, list):
                _list = []
                for t in tree:
                  r = clean_tree(t)
                  if r:
                    _list.append(r)
                return _list
      
              elif isinstance(tree, dict):
                _tree = {}
                for k, v in tree.items():
                  r = clean_tree(v)
                  if r:
                    _tree[k] = r
                return _tree
              else:
                return tree
      
            pruned_on_train_sub_tree = clean_tree(on_train_sub_tree)
            pruned_flat_on_train = [m for keep, m in
                                    zip(tf.nest.flatten(on_train_sub_tree),
                                        flat_on_train) if keep]
      
            on_train_metrics = tf.nest.pack_sequence_as(pruned_on_train_sub_tree,
                                                        pruned_flat_on_train)
      
            self.on_train_compiled_metrics = compile_utils.MetricsContainer(
              on_train_metrics, weighted_metrics=None, output_names=self.output_names,
              from_serialized=from_serialized)
      
        def train_step(self, data):
          x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
          # Run forward pass.
          with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(x, y, y_pred, sample_weight)
          self._validate_target_and_loss(y, loss)
          # Run backwards pass.
          self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
          return self.compute_metrics(x, y, y_pred, sample_weight, training=True)
      
        def compute_metrics(self, x, y, y_pred, sample_weight, training=False):
          del x  # The default implementation does not use `x`.
      
          if training:
            self.on_train_compiled_metrics.update_state(y, y_pred, sample_weight)
            metrics = self.on_train_metrics
          else:
            self.compiled_metrics.update_state(y, y_pred, sample_weight)
            metrics = self.metrics
          # Collect metrics to return
          return_metrics = {}
          for metric in metrics:
            result = metric.result()
            if isinstance(result, dict):
              return_metrics.update(result)
            else:
              return_metrics[metric.name] = result
          return return_metrics
      
        @property
        def on_train_metrics(self):
          metrics = []
          if self._is_compiled:
            # TODO(omalleyt): Track `LossesContainer` and `MetricsContainer` objects
            # so that attr names are not load-bearing.
            if self.compiled_loss is not None:
              metrics += self.compiled_loss.metrics
            if self.compiled_metrics is not None:
              metrics += self.on_train_compiled_metrics.metrics
      
          for l in self._flatten_layers():
            metrics.extend(l._metrics)  # pylint: disable=protected-access
          return metrics
      

      现在给定一个 keras 模型,我们可以包装它并使用禁用训练的指标对其进行编译:

      model: keras.Model = ...
      custom_model = CustomModel(inputs=model.input, outputs=model.output)
      
      train_enabled_metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
      
      # wrap train disabled metrics with `TrainDisabledMetric`:
      train_disabled_metrics = [
        TrainDisabledMetric(tf.keras.metrics.SparseCategoricalCrossentropy())]
      
      metrics = train_enabled_metrics + train_disabled_metrics
      
      custom_model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                           loss=tf.keras.losses.SparseCategoricalCrossentropy(
                             from_logits=True), metrics=metrics, )
      
      custom_model.fit(ds_train, epochs=6, validation_data=ds_test, )
      

      指标SparseCategoricalCrossentropy 仅在验证期间计算:

      Epoch 1/6
      469/469 [==============================] - 2s 2ms/step - loss: 0.3522 - sparse_categorical_accuracy: 0.8366 - val_loss: 0.1978 - val_sparse_categorical_accuracy: 0.9086 - val_sparse_categorical_crossentropy: 1.3197
      Epoch 2/6
      469/469 [==============================] - 1s 1ms/step - loss: 0.1631 - sparse_categorical_accuracy: 0.9526 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9587 - val_sparse_categorical_crossentropy: 1.1910
      Epoch 3/6
      469/469 [==============================] - 1s 1ms/step - loss: 0.1178 - sparse_categorical_accuracy: 0.9654 - val_loss: 0.1139 - val_sparse_categorical_accuracy: 0.9661 - val_sparse_categorical_crossentropy: 1.1369
      Epoch 4/6
      469/469 [==============================] - 1s 1ms/step - loss: 0.0909 - sparse_categorical_accuracy: 0.9735 - val_loss: 0.0981 - val_sparse_categorical_accuracy: 0.9715 - val_sparse_categorical_crossentropy: 1.0434
      Epoch 5/6
      469/469 [==============================] - 1s 1ms/step - loss: 0.0735 - sparse_categorical_accuracy: 0.9784 - val_loss: 0.0913 - val_sparse_categorical_accuracy: 0.9721 - val_sparse_categorical_crossentropy: 0.9862
      Epoch 6/6
      469/469 [==============================] - 1s 1ms/step - loss: 0.0606 - sparse_categorical_accuracy: 0.9823 - val_loss: 0.0824 - val_sparse_categorical_accuracy: 0.9761 - val_sparse_categorical_crossentropy: 1.0024
      
      
      

      【讨论】:

        猜你喜欢
        • 2017-10-11
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2020-03-20
        • 2022-12-20
        • 2018-07-05
        • 1970-01-01
        • 2017-07-12
        相关资源
        最近更新 更多