【问题标题】:Partial Backwards in PyTorch GraphPyTorch 图中的部分向后
【发布时间】:2021-02-09 13:24:02
【问题描述】:

我有一个中等大小的张量 x。在这个中等大小的张量上,应用一个计算量大的函数(前向和后向)q 来获得另一个中等大小的张量 y。

我使用 y 评估许多函数 f 产生一个标量,它们的计算成本不是特别高,但是使用大的内部状态会导致 Torch 的计算图很大。

现在我想通过以下方式计算 x 上的梯度

y = q(x)

for f in functions
    res += f(y)

res.backward()

这个实现的问题是所有函数 f 的图都被保留了。这会导致总内存使用量激增。

另一种可能性是计算

y = q(x)

for f in functions
    partial = f(y)
    partial.backward(retain_graph = True)

优点是每次函数评估后结果超出范围并释放图形,从而节省大量内存。然而,在这种情况下,函数 q(x) 会被多次反向评估,这非常耗时。

在理想情况下,我希望首先使用类似于第二个示例的代码计算 y 的梯度,然后只向后计算一次 q 以获得 x 的梯度。使用 PyTorch 的正确方法是什么?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    我认为这是实现它的方法:

    y = q(x)
    z = y.detach()
    z.requires_grad_(True)
    
    for f in functions:
        partial = f(y)
        partial.backward(retain_graph = True)
    y.backward(z.grad)
    

    您在 z 中累积所有梯度,其中 y 但在另一个计算图中,然后您在第一个图中传播这些梯度 (z.grad)

    【讨论】:

      猜你喜欢
      • 2019-12-06
      • 2017-10-31
      • 2020-07-20
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-06-27
      • 2020-05-23
      • 1970-01-01
      相关资源
      最近更新 更多