【发布时间】:2020-03-14 08:17:09
【问题描述】:
我对 Pytorch 的分类交叉熵损失的计算有疑问。 我编写了这个简单的代码 sn-p 并且因为我使用输出张量的 argmax 作为目标,所以我无法理解为什么损失仍然很高。
import torch
import torch.nn as nn
ce_loss = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
targets = torch.argmax(output, dim=1)
loss = ce_loss(outputs, targets)
print(loss)
感谢您帮助理解它。 最好的祝福 杰罗姆
【问题讨论】:
-
高是什么意思?查看我的答案,了解如何计算损失。
标签: pytorch