【问题标题】:Pytorch loss function dimensions do not matchPytorch 损失函数尺寸不匹配
【发布时间】:2019-08-13 14:16:16
【问题描述】:

我正在尝试使用批量训练运行词嵌入,如下所示。

def forward(self, inputs):
    print(inputs.shape)
    embeds = self.embeddings(inputs)
    print(embeds.shape)
    out = self.linear1(embeds)
    print(out.shape)
    out = self.activation_function1(out)
    print(out.shape)
    out = self.linear2(out).cuda()
    print(out.shape)
    out = self.activation_function2(out)
    print(out.shape)
    return out.cuda()

这里,我使用的是上下文大小 4,批量大小 32,嵌入大小 50,隐藏层大小 64,词汇大小 9927

“形状”函数的输出是

print(inputs.shape) ----> torch.Size([4, 32])

print(embeds.shape) ----> torch.Size([4, 32, 50])

print(out.shape) ----> torch.Size([4, 32, 64])

print(out.shape) ----> torch.Size([4, 32, 64])

print(out.shape) ----> torch.Size([4, 32, 9927])

print(out.shape) ----> torch.Size([4, 32, 9927])

这些形状是否正确?我很困惑。

另外,当我训练时,它会返回一个错误:

def train(epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader, 0):
    optimizer.zero_grad()
    output = model(torch.stack(data))
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

我在“损失 = 标准(输出,目标)”这一行中遇到错误。它显示“预期输入 batch_size (4) 与目标 batch_size (32) 匹配”。我的“前进”功能形状是否正确?我对批量训练不是很熟悉。如何使尺寸匹配?

-------编辑:在下面发布初始化代码-----

  def __init__(self, vocab_size, embedding_dim):
    super(CBOW, self).__init__()
    self.embeddings = nn.Embedding(vocab_size, embedding_dim)
    self.linear1 = nn.Linear(embedding_dim, 64)
    self.activation_function1 = nn.ReLU()
    self.linear2 = nn.Linear(64, vocab_size)
    self.activation_function2 = nn.LogSoftmax(dim = -1)

【问题讨论】:

    标签: python neural-network pytorch


    【解决方案1】:

    torch.nn.Linearforward 方法需要批量大小作为第一个参数。

    您将其作为第二个(第一个是时间步)提供,请使用 permute(1, 0, 2) 将其作为第一个。

    此外,线性层通常采用 2D 输入,第一个是批处理,第二个是输入的维度。你的是 3d 因为文字(我假设),也许你想使用循环神经网络(例如torch.nn.LSTM)?

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-05-28
      • 1970-01-01
      • 2020-02-19
      • 2012-09-11
      • 2014-12-14
      • 2022-09-27
      • 1970-01-01
      • 2022-01-15
      相关资源
      最近更新 更多