【问题标题】:How to avoid recalculating a function when we need to backpropagate through it twice?当我们需要通过它进行两次反向传播时,如何避免重新计算函数?
【发布时间】:2020-08-20 13:42:53
【问题描述】:

在 PyTorch 中,我想做如下计算:

l1 = f(x.detach(), y)
l1.backward(retain_graph=True)
l2 = -1*f(x, y.detach())
l2.backward()

其中f 是某个函数,xy 是需要梯度的张量。请注意,xy 可能都是先前使用共享参数的计算的结果(例如,可能是 x=g(z)y=g(w),其中 gnn.Module)。

问题是l1l2 在数字上都是相同的,直到减号,重复计算f(x,y) 两次似乎很浪费。能够计算一次并在结果上应用两次backward 会更好。有没有办法做到这一点?

一种可能性是手动调用autograd.grad 并更新每个nn.Parameter ww.grad 字段。但我想知道是否有更直接和更干净的方法来做到这一点,使用 backward 函数。

【问题讨论】:

  • 我看不到在这种情况下避免两次计算 f 的方法,但您可以通过让 l = l1 + l2 后跟 l.backward() 来避免向后调用两次。

标签: pytorch autograd


【解决方案1】:

我从here得到这个答案。

如果我们确保乘以-1 流经x 的梯度,我们可以计算一次f(x,y),而不分离xy。这可以使用register_hook 来完成:

x.register_hook(lambda t: -t)
l = f(x,y)
l.backward()

以下是证明此方法有效的代码:

import torch

lin = torch.nn.Linear(1, 1, bias=False)
lin.weight.data[:] = 1.0
a = torch.tensor([1.0])
b = torch.tensor([2.0])
loss_func = lambda x, y: (x - y).abs()

# option 1: this is the inefficient option, presented in the original question
lin.zero_grad()
x = lin(a)
y = lin(b)
loss1 = loss_func(x.detach(), y)
loss1.backward(retain_graph=True)
loss2 = -1 * loss_func(x, y.detach())  # second invocation of `loss_func` - not efficient!
loss2.backward()
print(lin.weight.grad)

# option 2: this is the efficient method, suggested in this answer. 
lin.zero_grad()
x = lin(a)
y = lin(b)
x.register_hook(lambda t: -t)
loss = loss_func(x, y)  # only one invocation of `loss_func` - more efficient!
loss.backward()
print(lin.weight.grad)  # the output of this is identical to the previous print, which confirms the method

# option 3 - this should not be equivalent to the previous options, used just for comparison
lin.zero_grad()
x = lin(a)
y = lin(b)
loss = loss_func(x, y)
loss.backward()
print(lin.weight.grad)

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-06-30
    • 2016-06-16
    • 1970-01-01
    • 1970-01-01
    • 2014-04-22
    • 1970-01-01
    相关资源
    最近更新 更多