【问题标题】:How to tranform the following TensorFlow code to PyTorch?如何将以下 TensorFlow 代码转换为 PyTorch?
【发布时间】:2021-01-13 18:00:44
【问题描述】:

我正在寻找 GAN 模型的损失函数,然后 this one 出来了:

gd_loss = tf.reduce_sum(tf.square(tf.abs(dx_real) - tf.abs(dx_fake))) + \
              tf.reduce_sum(tf.square(tf.abs(dy_real) - tf.abs(dy_fake))) + \
              tf.reduce_sum(tf.square(tf.abs(dz_real) - tf.abs(dz_fake)))

但我想将以下内容转换为 PyTorch,因为我使用的是 PyTorch 张量:

dx_real = t_target_image[:, 1:, :, :, :] - t_target_image[:, :-1, :, :, :]

t_target_image 是 TensorFlow 中的张量。

我该怎么做?

【问题讨论】:

    标签: tensorflow pytorch tensor code-conversion


    【解决方案1】:

    PyTorch 擅长对常见的矩阵运算坚持纯 Python 语法,所以你可以这样做

    gd_loss = ((dx_real.abs() - dx_fake.abs())**2).sum() + \
                  ((dy_real.abs() - dy_fake.abs())**2).sum() + \
                  ((dz_real.abs() - dz_fake.abs())**2).sum()
    

    如果您的问题是关于将 tensorflow 张量转换为 pytorch 张量,则应先转换为 numpy,然后使用 torch.as_tensor 转换为 pytorch。

    【讨论】:

    • 我的问题是关于第二部分 dx_real = t_target_image[:, 1:, :, :, :] - t_target_image[:, :-1, :, :, :] 。我放在这里的示例代码仅从 3D TensorFlow 张量获取 x 轴,我想知道可以使用 3D PyTorch 张量做同样的事情。
    • 是的,这种 numpy 风格的索引也应该适用于 pytorch 张量。您是否遇到了特定的错误消息?
    • 它没有给出错误,但也没有给出值。当我尝试计算这个值torch.mean((torch.abs(dx_real) 时,返回的值是tensor(nan, device='cuda:0') 所以我假设这个dx_real = t_target_image[:, 1:, :, :, :] - t_target_image[:, :-1, :, :, :] 是错误的。
    • 这是一个完全不同的问题,所以我建议你发布另一个关于具体问题的问题
    【解决方案2】:

    解决办法是:

    dx_real = real[:, :, :, 1:, :] - real[:, :, :, :-1, :]
    dy_real = real[:, :, 1:, :, :] - real[:, :, :-1, :, :]
    dz_real = real[:, :, :, :, 1:] - real[:, :, :, :, :-1]
    dx_fake = fake[:, :, :, 1:, :] - fake[:, :, :, :-1, :]
    dy_fake = fake[:, :, 1:, :, :] - fake[:, :, :-1, :, :]
    dz_fake = fake[:, :, :, :, 1:] - fake[:, :, :, :, :-1]
    gd_loss = torch.sum(torch.pow(torch.abs(dx_real) - torch.abs(dx_fake),2),dim=(2,3,4)) + \
              torch.sum(torch.pow(torch.abs(dy_real) - torch.abs(dy_fake),2),dim=(2,3,4)) + \
              torch.sum(torch.pow(torch.abs(dz_real) - torch.abs(dz_fake),2),dim=(2,3,4))
    return torch.sum(gd_loss)`
    

    【讨论】:

      猜你喜欢
      • 2019-08-03
      • 1970-01-01
      • 2021-05-23
      • 2019-10-03
      • 2022-11-20
      • 2021-08-28
      • 1970-01-01
      • 1970-01-01
      • 2018-04-27
      相关资源
      最近更新 更多