【发布时间】:2021-08-24 23:15:41
【问题描述】:
我正在 pytorch 中构建一个具有多个网络的模型。例如,让我们考虑netA 和netB。在损失函数中,我需要使用组合 netA(netB)。在优化的不同部分,我需要计算loss_func(netA(netB)) 的梯度,仅相对于netA 的参数,在另一种情况下,我需要计算netB 的参数的梯度。应该如何解决这个问题?
我的方法:在使用netA的参数计算梯度的情况下,我使用loss_func(netA(netB.detach()))。
如果我写loss_func(netA(netB).detach()),似乎netA 和netB 的两个参数都是分离的。
我尝试使用loss_func(netA.detach(netB)) 来仅分离netA 的参数,但它不起作用。 (我收到netA 没有属性分离的错误。)
【问题讨论】:
标签: neural-network pytorch gradient-descent detach