【问题标题】:Computing gradients only for the front-end network in Pytorch仅针对 Pytorch 中的前端网络计算梯度
【发布时间】:2021-07-07 10:04:02
【问题描述】:

我有一个非常简单的问题。

假设我有两个网络要训练(即 net1、net2)。 net1 的输出将在训练时输入 net2。 就我而言,我只想更新 net1:

optimizer=Optimizer(net1.parameters(), **kwargs)
loss=net2(net1(x))
loss.backward()
optimizer.step()

虽然这将实现我的目标,但它占用了太多的冗余内存,因为这将计算 net2 的梯度(导致 OOM 错误)。 因此我尝试了几次尝试来解决这个问题:

  1. torch.no_grad:
z=net1(x)
with torch.no_grad():
    loss=net2(z)

没有引发 OOM,但删除了所有渐变,包括来自 net1 的渐变。

  1. requires_grad=False:
net2.requires_grad=False
loss=net2(net1(x))

引发 OOM。

  1. 分离():
z=net1(x)
loss=net2(z).detach()

没有引发 OOM,但删除了所有渐变,包括来自 net1 的渐变。

  1. eval():
net2.eval()
loss=net2(net1(x))

引发 OOM。

有没有什么方法可以只计算前端网络(net1)的梯度以提高内存效率? 任何建议将不胜感激。

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    首先让我们试着理解为什么你的方法不起作用。

    1. 此上下文管理器禁用所有梯度计算。
    2. 由于net1 需要渐变,后续的requires_grad=False 将被忽略。
    3. 如果您在该状态下分离,这意味着梯度计算已经停止在那里
    4. eval 只是将 net2 设置为 eval 模式,根本不影响梯度计算。

    根据您的架构,OOM 错误可能已经来自于将所有中间值保存在您的计算图中(这在 CNN 中通常是一个问题),或者它可能来自于必须存储梯度(在完全连接的网络中更常见)。

    您可能正在寻找所谓的“检查点”,您甚至不必自己实现,您可以使用 pytorch 的检查点 API,查看 documentation

    这基本上可以让您分别计算和处理net1net2 的梯度。请注意,您确实需要所有梯度信息通过net2,否则您无法计算梯度wrt。 net1!

    【讨论】:

    • 你说得对。我需要所有梯度信息来训练 net1。 torch.utils.checkpoint 非常适合我!非常感谢
    猜你喜欢
    • 2021-04-01
    • 1970-01-01
    • 2019-09-24
    • 2020-05-12
    • 1970-01-01
    • 1970-01-01
    • 2023-03-22
    • 2021-09-11
    • 1970-01-01
    相关资源
    最近更新 更多