【发布时间】:2021-07-07 10:04:02
【问题描述】:
我有一个非常简单的问题。
假设我有两个网络要训练(即 net1、net2)。 net1 的输出将在训练时输入 net2。 就我而言,我只想更新 net1:
optimizer=Optimizer(net1.parameters(), **kwargs)
loss=net2(net1(x))
loss.backward()
optimizer.step()
虽然这将实现我的目标,但它占用了太多的冗余内存,因为这将计算 net2 的梯度(导致 OOM 错误)。 因此我尝试了几次尝试来解决这个问题:
- torch.no_grad:
z=net1(x)
with torch.no_grad():
loss=net2(z)
没有引发 OOM,但删除了所有渐变,包括来自 net1 的渐变。
- requires_grad=False:
net2.requires_grad=False
loss=net2(net1(x))
引发 OOM。
- 分离():
z=net1(x)
loss=net2(z).detach()
没有引发 OOM,但删除了所有渐变,包括来自 net1 的渐变。
- eval():
net2.eval()
loss=net2(net1(x))
引发 OOM。
有没有什么方法可以只计算前端网络(net1)的梯度以提高内存效率? 任何建议将不胜感激。
【问题讨论】:
标签: pytorch