【问题标题】:Find y prediction using PyTorch Lightning（使用 PyTorch Lightning 得到 y 的预测值）
【发布时间】:2022-08-07 16:56:48
【问题描述】:

我试图做的可能很简单,但我是新手,不知道如何开始。
我想看看训练好的模型如何对单个实例给出 y 的预测，并得到预测值与真实 y 的对照列表。看来我遗漏了几个步骤，也不确定该如何实现 predict_step。以下是我目前的代码：


# Load the MUTAG dataset and build stratified train/test DataLoaders.
# (Original snippet had scraping artifacts: backslash-escaped quotes that
# made the string literals invalid Python.)
mutag = ptgeom.datasets.TUDataset(root='.', name='MUTAG')

# Stratify on the graph-level label so both splits keep the class balance.
train_idx, test_idx = train_test_split(range(len(mutag)), stratify=[m.y[0].item() for m in mutag], test_size=0.25)

train_loader = ptgeom.loader.DataLoader(mutag[train_idx], batch_size=32, shuffle=True)
test_loader = ptgeom.loader.DataLoader(mutag[test_idx], batch_size=32)

class MUTAGClassifier(ptlight.LightningModule):
    """Binary graph classifier for MUTAG.

    Architecture: GCNConv -> GCNConv -> global mean pool -> Dropout -> Linear,
    producing one logit per graph, trained with BCE-with-logits.
    """

    def __init__(self):
        super().__init__()
        # MUTAG node features are 7-dimensional; hidden width 256.
        self.gc1 = ptgeom.nn.GCNConv(7, 256)
        self.gc2 = ptgeom.nn.GCNConv(256, 256)
        self.linear = torch.nn.Linear(256, 1)

    def forward(self, x, edge_index=None, batch=None, edge_weight=None):
        # "edge_weight" is not used for training, but only for the explainability part.
        # Accept either a Data/Batch object as the first argument or the
        # already-unpacked tensors; a missing edge_index signals the former.
        if edge_index is None:  # was `== None`; identity check is the correct idiom
            x, edge_index, batch = x.x, x.edge_index, x.batch
        x = F.relu(self.gc1(x, edge_index, edge_weight))
        x = F.relu(self.gc2(x, edge_index, edge_weight))
        x = ptgeom.nn.global_mean_pool(x, batch)  # one embedding per graph
        # BUG FIX: functional dropout defaults to training=True, so the original
        # applied dropout even during validation/prediction. Gate it on the
        # module's training flag instead.
        x = F.dropout(x, training=self.training)
        x = self.linear(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, batch, _):
        y_hat = self.forward(batch.x, batch.edge_index, batch.batch)
        # batch.y is (N,); unsqueeze to (N, 1) to match the logit shape.
        loss = F.binary_cross_entropy_with_logits(y_hat, batch.y.unsqueeze(1).float())
        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy(y_hat, batch.y.unsqueeze(1)), prog_bar=True, batch_size=32)
        return loss

    def validation_step(self, batch, _):
        x, edge_index, batch_vec = batch.x, batch.edge_index, batch.batch
        y_hat = self.forward(x, edge_index, batch_vec)
        self.log("val_accuracy", accuracy(y_hat, batch.y.unsqueeze(1)), prog_bar=True, batch_size=32)

    def predict_step(self, batch, _):
        # The missing piece the question asks about: trainer.predict(model,
        # dataloader) calls this once per batch and collects the return values
        # into a list, giving (true labels, logits) pairs for every batch.
        y_hat = self.forward(batch.x, batch.edge_index, batch.batch)
        return batch.y, y_hat


# Save a checkpoint every 50 epochs; save_top_k=-1 keeps all of them
# instead of only the best one.
checkpoint_callback = ptlight.callbacks.ModelCheckpoint(
    dirpath='./checkpoints/',
    filename='gnn-{epoch:02d}',
    every_n_epochs=50,
    save_top_k=-1)

# BUG FIX: the original used `gnn` without ever creating it (NameError).
gnn = MUTAGClassifier()

trainer = ptlight.Trainer(max_epochs=200, callbacks=[checkpoint_callback])

# Note: the test loader is used here as the validation set for fit().
trainer.fit(gnn, train_loader, test_loader)

 

    标签: pytorch-lightning


    【解决方案1】:

    对于 MNIST 的具体情况,您可以这样做:

    class MNISTModel(LightningModule):
        """Minimal one-layer MNIST classifier demonstrating predict_step."""

        def __init__(self):
            super().__init__()
            # Single linear layer: flattened 28x28 image -> 10 class logits.
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            flattened = x.view(x.size(0), -1)
            return torch.relu(self.l1(flattened))

        def training_step(self, batch, batch_nb):
            inputs, targets = batch
            return F.cross_entropy(self(inputs), targets)

        def predict_step(self, batch, batch_nb):
            inputs, targets = batch
            # self(...) dispatches to forward(); return (labels, predictions)
            # so trainer.predict() hands both back per batch.
            return targets, self(inputs)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.02)
    

    那么你也能:

    # NOTE(review): predict_step above returns a (y, y_hat) tuple per batch, so
    # trainer.predict() would yield a list of tuples; torch.cat(y_hat) below
    # assumes a list of plain tensors — confirm which predict_step variant is
    # active before running this as-is.
    x, y = next(iter(train_loader))            # get first batch in dataloader
    x = x[0].unsqueeze(0)                      # get first input instance in batch
    y = y[0].unsqueeze(0)                      # get first gt instance in batch
    y_hat = trainer.predict(mnist_model, x)    # predict
    y_hat = torch.cat(y_hat)                   # list to tensor
    y_hat = torch.argmax(y_hat, dim=1)         # logits -> predicted class index
    print(list(y))
    print(list(y_hat))
    

    【讨论】:

      猜你喜欢
      • 2021-04-24
      • 2021-07-05
      • 1970-01-01
      • 2021-03-03
      • 2021-09-22
      • 2023-01-23
      • 2023-03-10
      • 2021-05-06
      • 2021-02-28
      相关资源
      最近更新 更多