【发布时间】:2021-07-01 03:30:09
【问题描述】:
我在 BERT 之上搭建了一个二分类器,想获取它的预测概率来绘制 ROC 曲线。请问如何获得预测概率?这些概率将用于计算 ROC 曲线的 TPR、FPR 和阈值。
这是代码
class BertBinaryClassifier(nn.Module):
    """Binary text classifier: a BERT encoder followed by a one-unit linear head.

    Because ``forward`` finishes with a sigmoid, the value it returns is
    already the predicted probability of the positive class, not a raw logit.
    Those outputs can be fed directly to ROC / AUC computations.
    """

    def __init__(self, dropout=0.1):
        super().__init__()
        # Pretrained encoder whose pooled output feeds the classification head.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        # 768 matches the hidden size of bert-base-uncased; one output unit.
        self.linear = nn.Linear(768, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, tokens, masks=None):
        """Return the positive-class probability for each input sequence.

        Args:
            tokens: token-id tensor handed straight to ``self.bert``.
            masks: optional attention mask forwarded to ``self.bert``.

        Returns:
            Tensor of values in [0, 1] whose last dimension has size 1.
        """
        # Only the pooled (second) output of BERT is used by the head.
        _, pooled = self.bert(tokens, attention_mask=masks, output_all_encoded_layers=False)
        regularized = self.dropout(pooled)
        score = self.linear(regularized)
        return self.sigmoid(score)
# Config setting
BATCH_SIZE = 4
EPOCHS = 5

# Making dataloaders
# Training batches are drawn in random order each epoch; test batches keep the
# dataset order, so collected predictions line up with test_y_tensor.
train_dataset = torch.utils.data.TensorDataset(train_tokens_tensor, train_masks_tensor, train_y_tensor)
train_sampler = torch.utils.data.RandomSampler(train_dataset)
train_dataloader = torch.utils.data.DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)

test_dataset = torch.utils.data.TensorDataset(test_tokens_tensor, test_masks_tensor, test_y_tensor)
test_sampler = torch.utils.data.SequentialSampler(test_dataset)
test_dataloader = torch.utils.data.DataLoader(test_dataset, sampler=test_sampler, batch_size=BATCH_SIZE)

bert_clf = BertBinaryClassifier()
# Put the model on the same `device` the batches are moved to in the training
# and evaluation loops. A hard-coded .cuda() fails on CPU-only machines and
# disagrees with `device` when that is not the default GPU.
# NOTE(review): assumes `device` is defined earlier in the script — confirm.
bert_clf = bert_clf.to(device)
#wandb.watch(bert_clf)
optimizer = torch.optim.Adam(bert_clf.parameters(), lr=3e-6)
# training
# BCELoss consumes probabilities, which is what BertBinaryClassifier.forward
# returns (it ends in a sigmoid). Build the criterion once, not once per batch.
loss_func = nn.BCELoss()
# Exact number of batches per epoch, including a final partial batch.
num_batches = len(train_dataloader)
for epoch_num in range(EPOCHS):
    bert_clf.train()
    train_loss = 0
    for step_num, batch_data in enumerate(train_dataloader):
        token_ids, masks, labels = (t.to(device) for t in batch_data)
        preds = bert_clf(token_ids, masks)
        # NOTE(review): assumes labels are float and shaped like preds, (batch, 1) — confirm.
        batch_loss = loss_func(preds, labels)
        train_loss += batch_loss.item()

        bert_clf.zero_grad()
        batch_loss.backward()
        optimizer.step()
        #wandb.log({"Training loss": train_loss})
        print('Epoch: ', epoch_num + 1)
        # Report a 1-based step out of the real batch count, plus the running mean loss.
        print("\r" + "{0}/{1} loss: {2} ".format(step_num + 1, num_batches, train_loss / (step_num + 1)))
# evaluating on test
# BertBinaryClassifier.forward ends in a sigmoid, so its outputs are already
# positive-class probabilities; they are exactly what an ROC curve needs.
bert_clf.eval()
bert_predicted = []  # hard decisions at threshold 0.5
all_logits = []      # positive-class probabilities (historical name; not logits)
probs = []           # the same probabilities, flattened, as Python floats
all_labels = []      # ground-truth labels, aligned element-for-element with probs
loss_func = nn.BCELoss()
with torch.no_grad():
    test_loss = 0
    for step_num, batch_data in enumerate(test_dataloader):
        token_ids, masks, labels = (t.to(device) for t in batch_data)
        batch_probs = bert_clf(token_ids, masks)
        test_loss += loss_func(batch_probs, labels).item()

        # Copy to host once per batch; .detach() is unnecessary under no_grad().
        numpy_probs = batch_probs.cpu().numpy()
        # Store plain floats rather than 0-d device tensors so the lists can be
        # passed straight to metric functions.
        probs.extend(numpy_probs.ravel().tolist())
        all_labels.extend(labels.cpu().numpy().ravel().tolist())
        #wandb.log({"Testing loss": test_loss})
        bert_predicted += list(numpy_probs[:, 0] > 0.5)
        all_logits += list(numpy_probs[:, 0])

# FPR, TPR and thresholds for the ROC curve can then be obtained with, e.g.:
#     from sklearn.metrics import roc_curve
#     fpr, tpr, thresholds = roc_curve(all_labels, probs)
我已经能够得到用于计算准确率或 F1 分数的预测结果,但不知道如何得到绘制 ROC 曲线所需的概率。谢谢!
【问题讨论】:
-
网络的输出本身就是概率(你把它命名为 logits,但它其实并不是 logits)。你应该直接用这些输出来计算 ROC 曲线。
标签: python python-3.x pytorch bert-language-model huggingface-transformers