【问题标题】:AttributeError on variable input of custom loss function in PyTorchPyTorch中自定义损失函数的变量输入上的AttributeError
【发布时间】:2018-05-06 02:03:43
【问题描述】:

我制作了一个自定义损失函数来计算多输出多标签问题的交叉熵 (CE)。在课堂上,我想将我输入的目标变量设置为不需要渐变。我在类外使用预定义函数(取自 pytorch 源代码)在 forward 函数中执行此操作。

    def _assert_no_grad(variable):
        assert not variable.requires_grad

    def forward(self, predicted, target):
        """
        Computes cross entropy between targets and predictions.
        """
        # No gradient over target
        _assert_no_grad(target)

        # Define variables
        p = predicted.clamp(0.01, 0.99)
        t = target.float()

        #Compute cross entropy
        h1 = p.log()*t
        h2 = (1-t)*((1-p).log())
        ce = torch.add(h1, h2)
        ce_out = torch.mean(ce, 1)
        ce_out = torch.mean(ce_out, 0)

        # Save for backward step
        self.save_for_backward(ce_out)

此时,当我在批处理 for 循环中运行代码时(见下文),我收到以下错误:

AttributeError: 'torch.FloatTensor' 对象没有属性 'requires_grad'

这似乎很简单,因为我们应该传递一个 torch.autograd.Variable,但是我已经这样做了,如下面的 sn-p 所示。

for t in range(50):

    print('Epoch {}'.format(t))
    if t > 0:
        print('Loss ->', loss)

    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        # Wrap in Variable
        x_in, target = Variable(x_batch), Variable(y_batch)

        predicted = model(x_in)

        # Compute and print loss
        loss = criterion(predicted, target)

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

要添加注释,我的最终目标是生成一个行为类似于 BCELoss 的类,除了多个标签,而不仅仅是二进制。我觉得我已经浏览了整个 PyTorch 文档,主要是在使用这个和一些论坛条目。 http://pytorch.org/docs/master/notes/extending.html

所以

【问题讨论】:

  • 您的问题不清楚。您的代码中的模型是什么?正如错误所说,浮点张量没有任何名为requires_grad的属性,该属性属于变量。你可以在forward函数中打印predictedtarget的类型,看看它们是否是Variable。

标签: python pytorch


【解决方案1】:

问题出在“target.float()”行中,它将您的 t 变量转换为张量。您可以直接使用 target,在 CE 计算中没有任何问题。

另外,我猜您实际上并不需要“self.save_for_backward(ce_out)”,因为我猜您正在定义 nn.Module 类,它将在内部处理向后传递。

【讨论】:

  • 删除 .float() 解决了这个问题。我也是从 autograd.Function 打来的,nn.Module 解决了很多问题。调用 nn.Module 进行自定义损失似乎有点违反直觉。
猜你喜欢
  • 2019-05-27
  • 2021-05-04
  • 1970-01-01
  • 2021-11-13
  • 2020-02-25
  • 2018-09-24
  • 2021-02-20
  • 2018-08-26
  • 2020-05-11
相关资源
最近更新 更多