【发布时间】:2021-02-09 13:24:02
【问题描述】:
我有一个中等大小的张量 x。在这个中等大小的张量上,应用一个计算量大的函数(前向和后向)q 来获得另一个中等大小的张量 y。
我使用 y 评估许多函数 f 产生一个标量,它们的计算成本不是特别高,但是使用大的内部状态会导致 Torch 的计算图很大。
现在我想通过以下方式计算 x 上的梯度
y = q(x)
for f in functions
res += f(y)
res.backward()
这个实现的问题是所有函数 f 的图都被保留了。这会导致总内存使用量激增。
另一种可能性是计算
y = q(x)
for f in functions
partial = f(y)
partial.backward(retain_graph = True)
优点是每次函数评估后结果超出范围并释放图形,从而节省大量内存。然而,在这种情况下,函数 q(x) 会被多次反向评估,这非常耗时。
在理想情况下,我希望首先使用类似于第二个示例的代码计算 y 的梯度,然后只向后计算一次 q 以获得 x 的梯度。使用 PyTorch 的正确方法是什么?
【问题讨论】:
标签: pytorch