【Question title】: Confusion matrix in PyTorch Lightning
【Posted】: 2021-12-08 08:26:57
【Question description】:

I am running AlexNet on the CIFAR10 dataset with PyTorch Lightning. Here is my model:

class SelfSupervisedModel(pl.LightningModule):
    def __init__(self, hparams=None, num_classes=10, batch_size=128):
        super(SelfSupervisedModel, self).__init__()

        self.batch_size = batch_size
        self.loss_fn = nn.CrossEntropyLoss()
        self.hparams["lr"] = ModelHelper.Hyperparam.Learning_rate

        self.model = torchvision.models.alexnet(pretrained=False)

    def forward(self, x):
        return self.model(x)

    def training_step(self, train_batch, batch_idx):
        inputs, targets = train_batch
        predictions = self(inputs)
        loss = self.loss_fn(predictions, targets)
        return {'loss': loss}

    def validation_step(self, test_batch, batch_idx):
        inputs, targets = test_batch
        predictions = self(inputs)
        val_loss = self.loss_fn(predictions, targets)
        _, preds = tf.max(predictions, 1)
        acc = tf.sum(preds == targets.data) / (targets.shape[0] * 1.0)
        return {'val_loss': val_loss, 'val_acc': acc, 'target': targets, 'preds': predictions}

    def validation_epoch_end(self, outputs):
        avg_loss = tf.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = tf.stack([x['val_acc'].float() for x in outputs]).mean()
        logs = {'val_loss': avg_loss, 'val_acc': avg_acc}
        print(f'validation_epoch_end logs => {logs}')

        OutputMatrix.predictions = tf.cat([tmp['preds'] for tmp in outputs])
        OutputMatrix.targets = tf.cat([tmp['target'] for tmp in outputs])
        
        return {'progress_bar': logs}

    def configure_optimizers(self):
        return tf.optim.SGD(self.parameters(), lr=self.hparams["lr"], momentum=0.9)

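I store the predictions and the ground-truth labels in OutputMatrix.predictions and OutputMatrix.targets, which are then used to generate the confusion matrix, as shown below:

I'm fairly sure this is not what the output should look like, but I can't find where the mistake is. Any help would be appreciated.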

【Comments on the question】:

Tags: python machine-learning pytorch conv-neural-network pytorch-lightning


【Solution 1】:

I recommend using TorchMetrics together with the module's built-in log method, so the code could look like this:

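import torchmetrics
from pytorch_lightning import LightningModule

class MyModule(LightningModule):

    def __init__(self):
        ...
        # Note: torchmetrics >= 0.11 requires the task to be stated, e.g.
        # torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        self.train_acc(preds, y)
        # Logging the metric object lets Lightning handle compute and reset
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch  # unpack the batch (missing in the original snippet)
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)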

You can also find more details in the TorchMetrics documentation section on PL (PyTorch Lightning) integration.
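
To connect the answer back to the confusion matrix in the question, here is a minimal sketch in plain PyTorch. It assumes the stored predictions are raw logits of shape (N, num_classes), as returned by the question's validation_step, and that `tf` in the question is an alias for `torch`. The helper name `confusion_matrix` and the toy tensors are illustrative and not from the original post. The key step is converting logits to class indices with argmax before counting.

```python
import torch


def confusion_matrix(logits, targets, num_classes):
    """Return a (num_classes x num_classes) count matrix.

    Rows are true classes, columns are predicted classes.
    `logits` is assumed to be a float tensor of shape (N, num_classes);
    `targets` is assumed to be an integer tensor of shape (N,).
    """
    # Convert raw scores to predicted class indices.
    preds = logits.argmax(dim=1)
    # Encode each (target, pred) pair as a single flat index.
    flat_index = targets * num_classes + preds
    # Count occurrences of every pair, including pairs that never occur.
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


# Toy data (hypothetical): 4 samples, 3 classes.
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.9, 0.2, 0.4],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 1, 2, 2])

cm = confusion_matrix(logits, targets, num_classes=3)
print(cm)
```

If the matrix in the question looked wrong, two things may be worth checking. Both are guesses, since the plotting code is not shown: whether logits rather than class indices were passed to the plotting routine, and whether the network really outputs 10 classes, since torchvision's alexnet defaults to num_classes=1000.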

【Comments】:
