【发布时间】:2019-12-05 03:36:19
【问题描述】:
我试图使用 PyTorch 框架重新实现 TensorFlow 代码。下面我包含了 TF 示例代码和我的 PyT 解释,目标大小为 (Batch, 9, 9, 4),网络输出大小为 (Batch, 9, 9, 4)
TensorFlow 实现:
loss = tf.nn.softmax_cross_entropy_with_logits(labels=target, logits=output)
loss = tf.matrix_band_part(loss, 0, -1) - tf.matrix_band_part(loss, 0, 0)
PyTorch 实现:
output = torch.tensor(output, requires_grad=True).view(-1, 4)
target = torch.tensor(target).view(-1, 4).argmax(1)
loss = torch.nn.CrossEntropyLoss(reduction='none')
my_loss = loss(output, target).view(-1,9,9)
对于 PyTorch 的实现,我不确定如何实现 tf.matrix_band_part。我正在考虑定义一个掩码,但我不确定这是否会损害反向传播。我知道torch.triu,但这个函数不适用于二维以上的张量。
【问题讨论】:
-
btw 掩码不会伤害反向传播
-
你是对的,他们最近添加了这个。我升级到 PyTorch 1.3.0 并且工作正常。您可以将其写为响应。
标签: tensorflow pytorch