【问题标题】:Shape of target and predictions tensors in PyTorch loss functionsPyTorch 损失函数中目标和预测张量的形状
【发布时间】:2020-06-08 09:17:28
【问题描述】:

我对@9​​87654322@ 中张量的输入形状感到困惑。 我正在尝试为文本序列实现一个简单的自动编码器。我的问题的核心可以用下面的代码来说明

predictions = torch.rand(2, 3, 4)
target = torch.rand(2, 3)
print(predictions.shape)
print(target.shape)
nn.CrossEntropyLoss(predictions.transpose(1, 2), target)

在我的例子中,预测的形状为(time_step, batch_size, vocabulary_size),而目标的形状为(time_step, batch_size)。接下来我根据description 转置预测,这表示预测的第二维应该是类的数量——在我的例子中是词汇大小。代码返回错误RuntimeError: bool value of Tensor with more than one value is ambiguous。有人可以请教我如何使用该死的东西吗?提前谢谢!

【问题讨论】:

    标签: python pytorch recurrent-neural-network


    【解决方案1】:

    您不是在调用损失函数,而是在构建它。 nn.CrossEntropyLoss 构造函数的签名是:

    nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
    

    您将预测设置为weight,将目标设置为size_average, 其中weight 是类的可选重新缩放,size_average 已弃用,但需要一个布尔值。目标是大小为 [2, 3] 的张量,不能转换为布尔值。

    您需要先创建损失函数,因为您不使用构造函数的任何可选参数,因此您无需指定任何参数。

    # Create the loss function
    cross_entropy = nn.CrossEntropyLoss()
    
    # Call it to calculate the loss for your data
    loss = cross_entropy(predictions.transpose(1, 2), target)
    

    也可以直接使用功能版nn.functional.cross_entropy

    import torch.nn.functional as F
    
    loss = F.cross_entropy(predictions.transpose(1, 2), target)
    

    类版相比函数版的优势在于,额外参数只需指定一次(如weight),而不必每次都手动提供。

    关于张量的维度,批次大小必须是第一个维度,因为损失是批次中每个元素的平均值,因此您有大小为 [batch_size] 的损失张量。如果您使用reduction="none",您将获得批次中每个元素的这些损失,但默认情况下 (reduction="mean") 会返回这些损失的平均值。如果平均值是跨时间步而不是批次取的,那么结果会有所不同。

    最后,目标必须是类索引,这意味着它们的类型必须是 torch.long 而不是 torch.float。在这个随机选择的示例中,您可以使用torch.randint 创建随机类。

    predictions = torch.rand(2, 3, 4)
    target = torch.randint(4, (2, 3))
    
    # Reorder the dimensions
    # From: [time_step, batch_size, vocabulary_size]
    # To: [batch_size, vocabulary_size, time_step]
    predictions = predictions.permute(1, 2, 0)
    # From: [time_step, batch_size]
    # To: [batch_size, time_step]
    target = target.transpose(0, 1)
    
    F.cross_entropy(predictions, target)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2021-08-08
      • 2021-05-30
      • 2020-12-02
      • 2019-02-21
      • 2022-01-16
      • 2019-11-12
      • 1970-01-01
      相关资源
      最近更新 更多