【发布时间】:2020-04-30 06:07:51
【问题描述】:
我在询问关于 NLLLoss 损失函数的 C 类。
文档说明:
负对数似然损失。用 C 类训练分类问题很有用。
基本上,在那之后的一切都取决于您知道什么是 C 类,我以为我知道什么是 C 类,但文档对我来说没有多大意义。特别是当它描述(N, C) where C = number of classes 的预期输入时。这就是我感到困惑的地方,因为我认为 C 类仅引用 输出。我的理解是 C 类是分类的一个热门向量。我经常在教程中发现NLLLoss 经常与LogSoftmax 配对来解决分类问题。
我希望在以下示例中使用 NLLLoss:
# Some random training data
input = torch.randn(5, requires_grad=True)
print(input) # tensor([-1.3533, -1.3074, -1.7906, 0.3113, 0.7982], requires_grad=True)
# Build my NN (here it's just a LogSoftmax)
m = nn.LogSoftmax(dim=0)
# Train my NN with the data
output = m(input)
print(output) # tensor([-2.8079, -2.7619, -3.2451, -1.1432, -0.6564], grad_fn=<LogSoftmaxBackward>)
loss = nn.NLLLoss()
print(loss(output, torch.tensor([1, 0, 0])))
以上在最后一行引发以下错误:
ValueError:预期 2 个或更多维度(得到 1 个)
我们可以忽略这个错误,因为很明显我不明白我在做什么。这里我将解释我对上述源代码的意图。
input = torch.randn(5, requires_grad=True)
随机一维数组与[1, 0, 0] 的一个热向量配对以进行训练。我正在尝试对一个十进制数的热向量进行二进制位。
m = nn.LogSoftmax(dim=0)
LogSoftmax 的文档说输出将与输入具有相同的形状,但我只看到了 LogSoftmax(dim=1) 的示例,因此我一直在尝试使这项工作正常工作,因为我做不到找一个相关的例子。
print(loss(output, torch.tensor([1, 0, 0])))
所以现在我有了 NN 的输出,我想知道我的分类 [1, 0, 0] 的损失。在这个例子中,数据是什么并不重要。我只想要一个代表分类的热向量的损失。
此时,我在尝试解决与预期输出和输入结构相关的损失函数中的错误时遇到了困难。我尝试在输出和输入上使用view(...) 来修复形状,但这只会让我遇到其他错误。
所以这又回到了我最初的问题,我将展示文档中的示例来解释我的困惑:
m = nn.LogSoftmax(dim=1)
loss = nn.NLLLoss()
input = torch.randn(3, 5, requires_grad=True)
train = torch.tensor([1, 0, 4])
print('input', input) # input tensor([[...],[...],[...]], requires_grad=True)
output = m(input)
print('train', output, train) # tensor([[...],[...],[...]],grad_fn=<LogSoftmaxBackward>) tensor([1, 0, 4])
x = loss(output, train)
同样,LogSoftmax 上有 dim=1,这让我很困惑,因为请查看 input 数据。这是一个 3x5 张量,我迷路了。
这是NLLLoss 函数的第一个输入的文档:
输入:(N, C)(N,C) 其中 C = 类数
输入按类数分组?
那么张量输入的每一行都与训练张量的每一元素相关联?
如果我改变输入张量的第二个维度,那么什么都不会发生,我不明白发生了什么。
input = torch.randn(3, 100, requires_grad=True)
# 3 x 100 still works?
所以我不明白这里的 C 类是什么,我认为 C 类是一种分类(如标签),仅对 NN 的输出有意义。
我希望你能理解我的困惑,因为神经网络的输入形状不应该独立于用于分类的一个热向量的形状吗?
代码示例和文档都说输入的形状是由分类数定义的,我不太明白为什么。
我试图研究文档和教程以了解我所缺少的内容,但在几天无法超越这一点后,我决定问这个问题。这让我感到羞愧,因为我认为这将是更容易学习的事情之一。
【问题讨论】:
标签: python machine-learning neural-network pytorch