【发布时间】:2022-08-07 16:56:48
【问题描述】:
我试图做的可能很简单,但我是新手,不知道如何开始。
我想查看训练好的模型对单个样本的 y 的预测结果,并得到一份预测值与实际 y 值的对照列表。
看来我漏掉了几个步骤,而且不确定如何实现 predict_step。以下是我目前的代码:
# Load the MUTAG graph-classification dataset, stored under the current directory.
mutag = ptgeom.datasets.TUDataset(root='.', name='MUTAG')

# Split graph indices 75/25, stratified on each graph's label so that both
# splits keep the same class ratio.
train_idx, test_idx = train_test_split(
    range(len(mutag)),
    stratify=[m.y[0].item() for m in mutag],
    test_size=0.25,
)

# Only the training loader is shuffled; the test loader keeps dataset order.
train_loader = ptgeom.loader.DataLoader(mutag[train_idx], batch_size=32, shuffle=True)
test_loader = ptgeom.loader.DataLoader(mutag[test_idx], batch_size=32)
class MUTAGClassifier(ptlight.LightningModule):
    """Binary graph classifier producing one logit per graph.

    Architecture: GCNConv --> GCNConv --> graph mean pooling --> Dropout --> Linear.
    The output is a raw logit (no sigmoid); the training loss is
    ``binary_cross_entropy_with_logits``.
    """

    def __init__(self):
        super().__init__()
        # 7 input node-feature channels, 256 hidden channels, 1 output logit.
        self.gc1 = ptgeom.nn.GCNConv(7, 256)
        self.gc2 = ptgeom.nn.GCNConv(256, 256)
        self.linear = torch.nn.Linear(256, 1)

    def forward(self, x, edge_index=None, batch=None, edge_weight=None):
        """Compute per-graph logits.

        Args:
            x: Either the node-feature tensor, or (when ``edge_index`` is
                omitted) a whole data/batch object exposing ``.x``,
                ``.edge_index`` and ``.batch``.
            edge_index: Graph connectivity; ``None`` means ``x`` is a data object.
            batch: Node-to-graph assignment vector used for pooling.
            edge_weight: Optional per-edge weights forwarded to both GCN layers.

        Returns:
            The output of the final linear layer, one row per pooled graph.
        """
        # Note: "edge_weight" is not used for training, but only for the explainability part
        if edge_index is None:
            x, edge_index, batch = x.x, x.edge_index, x.batch
        x = F.relu(self.gc1(x, edge_index, edge_weight))
        x = F.relu(self.gc2(x, edge_index, edge_weight))
        x = ptgeom.nn.global_mean_pool(x, batch)
        # F.dropout defaults to training=True; tie it to the module's mode so
        # dropout is disabled during validation and prediction.
        x = F.dropout(x, training=self.training)
        x = self.linear(x)
        return x

    def configure_optimizers(self):
        """Return the Adam optimizer over all parameters (lr=1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, batch, _):
        """Run one training batch, log loss and accuracy, and return the loss."""
        y_hat = self.forward(batch.x, batch.edge_index, batch.batch)
        target = batch.y.unsqueeze(1)
        loss = F.binary_cross_entropy_with_logits(y_hat, target.float())
        # Log with the real number of graphs in this batch: the last batch of
        # an epoch can be smaller than the loader's batch size.
        num_graphs = batch.num_graphs
        self.log("train_loss", loss, batch_size=num_graphs)
        self.log("train_accuracy", accuracy(y_hat, target), prog_bar=True, batch_size=num_graphs)
        return loss

    def validation_step(self, batch, _):
        """Run one validation batch and log its accuracy."""
        x, edge_index, batch_idx = batch.x, batch.edge_index, batch.batch
        y_hat = self.forward(x, edge_index, batch_idx)
        self.log(
            "val_accuracy",
            accuracy(y_hat, batch.y.unsqueeze(1)),
            prog_bar=True,
            batch_size=batch.num_graphs,
        )
# Write a checkpoint every 50 epochs; save_top_k=-1 keeps every checkpoint
# rather than only the best ones.
checkpoint_callback = ptlight.callbacks.ModelCheckpoint(
    dirpath='./checkpoints/',
    filename='gnn-{epoch:02d}',
    every_n_epochs=50,
    save_top_k=-1,
)

# Instantiate the model to be trained.
gnn = MUTAGClassifier()

trainer = ptlight.Trainer(max_epochs=200, callbacks=[checkpoint_callback])
# test_loader is passed in the validation-dataloader position, so
# "val_accuracy" is measured on the held-out test split.
trainer.fit(gnn, train_loader, test_loader)