【问题标题】:Custom train_step with combination of losses and regularization loss结合损失和正则化损失的自定义 train_step
【发布时间】:2022-11-07 20:33:29
【问题描述】:

我需要在我的 Tensorflow 的model.fit 中使用自定义train_step 来使用两个损失的线性组合作为损失函数,即交叉熵(在监督分类任务中通常)和另一个可能是任何损失的损失 - 不是对我的问题非常重要。此外,我想仍然使用我在模型层中定义的regularization_loss(例如, L2 regularization)。我想知道以下代码是否正确实现了我想做的事情。

特别是,由于 L2 正则化增加了损失的惩罚(所以在这种情况下,ce_loss 应该已经包含 L2 正则化项),我认为将additional_loss 添加到ce_loss 是正确的。这个对吗?

import tensorflow as tf


class CustomModel(tf.keras.Model):

    def __init__(self, model):
        super(CustomModel, self).__init__()
        self.model = model

    def compile(self, optimizer, loss, metrics, another_loss, gamma=0.2):
        super(CustomModel, self).compile(optimizer=optimizer, loss=loss, metrics=metrics)
        self.another_loss= another_loss
        # gamma weights the impact of another_loss
        self.gamma = gamma

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)  # Forward pass

            # Compute the loss value
            # (the loss function is configured in `compile()`)
            ce_loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

            additional_loss = self.another_loss(y, y_pred)
            combined_loss = ce_loss + self.gamma * additional_loss 


        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(combined_loss , trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        y_pred = self.model(x, training=False)  # Forward pass
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        self.compiled_metrics.update_state(y, y_pred)
        # self.compiled_metrics
        return {m.name: m.result() for m in self.metrics}

【问题讨论】:

    标签: python tensorflow keras


    【解决方案1】:

    是的,这是正确的。 L2 正则化项已包含在交叉熵损失中,因此您只需将额外损失添加到交叉熵损失中即可。

    希望我的回答有帮助。

    【讨论】:

      猜你喜欢
      • 2021-01-22
      • 2018-06-25
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-12-05
      相关资源
      最近更新 更多