【发布时间】:2021-01-27 09:33:43
【问题描述】:
这是我使用 Pytorch 进行图像分类的代码,但我无法获得正确的准确性。 精度超过100,谁能帮我找出错误。
def trained_model(criterion, optimizer, epochs=5):
epoch_loss = 0.0
epoch_accuracy = 0
running_loss = 0
running_accuracy = 0
total = 0
for epoch in range(epochs):
print('epoch : {}/{}'.format(epoch+1, epochs))
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
_, predictions = torch.max(outputs, dim=1)
loss.backward()
optimizer.step()
running_loss += loss.item()
running_accuracy += torch.sum(predictions == labels.data)
epoch_loss = running_loss / len(train_dataset)
epoch_accuracy = running_accuracy / len(train_dataset)
print('Loss:{:.4f} , Accuracy : {:.4f} '.format(epoch_loss, epoch_accuracy))
return model
【问题讨论】:
-
请将代码添加为格式化文本,而不是图像
-
您的标签是单热编码还是索引?
-
@couka 谢谢你的建议,我不知道,因为我是新手。
-
@Ivan Thnak 你!它有帮助。
标签: pytorch data-science image-classification