【问题标题】:PyTorch equivalent of a Tensorflow loss functionPyTorch 等效于 Tensorflow 损失函数
【发布时间】:2019-12-05 03:36:19
【问题描述】:

我试图使用 PyTorch 框架重新实现 TensorFlow 代码。下面我包含了 TF 示例代码和我的 PyT 解释,目标大小为 (Batch, 9, 9, 4),网络输出大小为 (Batch, 9, 9, 4)

TensorFlow 实现:

loss = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=output)
loss = tf.matrix_band_part(loss, 0, -1) - tf.matrix_band_part(loss, 0, 0)

PyTorch 实现:

output = torch.tensor(output, requires_grad=True).view(-1, 4)
target = torch.tensor(target).view(-1, 4).argmax(1)

loss = torch.nn.CrossEntropyLoss(reduction='none')
my_loss = loss(output, target).view(-1,9,9)

对于 PyTorch 的实现,我不确定如何实现 tf.matrix_band_part。我正在考虑定义一个掩码,但我不确定这是否会损害反向传播。我知道torch.triu,但这个函数不适用于二维以上的张量。

【问题讨论】:

  • torch.triu 可以很好地处理批次(如per docs)。你可以通过einsum:torch.einsum('...ii->...i', A)获取对角元素。
  • btw 掩码不会伤害反向传播
  • 你是对的,他们最近添加了这个。我升级到 PyTorch 1.3.0 并且工作正常。您可以将其写为响应。

标签: tensorflow pytorch


【解决方案1】:

因为(至少)1.2.0 版torch.triu 可以很好地处理批次(如per docs)。

你可以通过einsum:torch.einsum('...ii->...i', A)获取对角元素。

应用遮罩不会伤害反向传播。您可以将其视为投影(这显然适用于反向传播)。

【讨论】:

    猜你喜欢
    • 2021-08-01
    • 1970-01-01
    • 2020-03-28
    • 2022-01-15
    • 2020-10-30
    • 2022-10-12
    • 2019-05-27
    • 2021-10-05
    • 2019-02-03
    相关资源
    最近更新 更多