【问题标题】:Why detach needs to be called on variable in this example?为什么在这个例子中需要在变量上调用分离?
【发布时间】:2022-04-17 14:17:57
【问题描述】:

我正在查看这个示例 - https://github.com/pytorch/examples/blob/master/dcgan/main.py,我有一个基本问题。

fake = netG(noise)
label = Variable(label.fill_(fake_label))
output = netD(fake.detach()) # detach to avoid training G on these labels
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
optimizerD.step()

我明白为什么我们在变量fake 上调用detach(),这样就不会为生成器参数计算梯度。我的问题是,因为optimizerD.step() 将只更新与鉴别器相关的参数,这有关系吗?

OptimizerD 定义为: optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

此外,在下一步我们将更新生成器的参数时,在此之前我们将调用netG.zero_grad(),它最终会删除所有先前计算的梯度。此外,当我们更新 G 网络的参数时,我们会这样做 - output = netD(fake)。在这里,我们没有使用分离。为什么?

那么,为什么要在上面的代码中分离变量(第 3 行)?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    原始答案(错误/不完整)

    你说得对,optimizerD 只更新netDnetG 上的渐变在调用netG.zero_grad() 之前没有使用,所以不需要分离,它只是节省时间,因为你没有计算生成器的梯度。

    您基本上也在自己回答其他问题,您不会在第二个块中分离 fake,因为您特别想计算 netG 上的梯度以便能够更新其参数。

    注意the second block real_label 是如何作为fake 的对应标签的,所以如果鉴别器发现假输入是真实的,那么最终损失很小,反之亦然,这正是你想要的发电机。不确定这是否让您感到困惑,但这确实是与在虚假输入上训练判别器相比的唯一区别。

    编辑

    请看 FatPanda 的评论!我原来的答案实际上是不正确的。当调用 .backward() 时,Pytorch 会破坏(部分)计算图。如果在errD_fake.backward() 之前不分离,那么稍后的errG.backward() 调用将无法反向传播到生成器,因为所需的图不再可用(除非您指定retain_graph=True)。我松了一口气,Soumith 犯了同样的错误:D

    【讨论】:

      【解决方案2】:

      投票最多的答案是不正确/不完整!

      检查这个:https://github.com/pytorch/examples/issues/116,看看@plopd 的回答:

      这不是真的。当我们实际更新生成器时,有必要从图中分离fake 以避免通过 G 前向传递噪声。如果我们不分离,那么尽管 fake 不需要用于 D 的梯度更新,它仍将被添加到计算图中,并且由于 backward 传递清除了图中的所有变量(@987654326 @默认),fake在G更新时将不可用。

      这个帖子也澄清了很多:https://zhuanlan.zhihu.com/p/43843694(中文)。

      【讨论】:

        【解决方案3】:

        因为假变量现在是生成器图[1] 的一部分,但您不希望这样。因此,在将其放入鉴别器之前,您必须将其与他“分离”。

        【讨论】:

          【解决方案4】:

          那是因为如果你不在output = netD(fake.detach()).view(-1) 中使用fake.detach(),那么 fake 只是整个计算图中的一些中间变量,它跟踪 netG 和 netD 中的梯度。当您致电netD.backward() 时,图表将被释放。这意味着计算图中没有更多关于 netG() 的梯度信息。然后当你稍后使用 errG.backward() 时,它会导致类似的错误

          尝试第二次向后浏览图表

          如果不使用 fake.detach(),可以使用netD.backward(retain_graph=True)

          【讨论】:

            【解决方案5】:

            让我告诉你。 detach的作用是冻结梯度下降。无论是判别网络还是生成网络,我们都更新logD(G(z))。对于判别网络,冻结G不影响整体梯度更新(即内层函数被认为是一个常数,不影响外层函数求梯度),但反之,如果D被冻结,则有没有办法完成梯度更新。因此,我们在训练生成器时没有使用冻结 D 的梯度。所以,对于生成器,我们确实计算了D的梯度,但是我们没有更新D的权重(只写了optimizer_g.step),所以训练生成器时判别器不会改变。你可能会问,这就是为什么,当你训练判别器时,你需要添加分离。这不是额外的动作吗? 因为我们冻结了梯度,所以我们可以加快训练速度,所以我们可以在可以使用的地方使用它。这不是一项额外的任务。那么我们在训练生成器的时候,因为logD(G(z)),没有办法冻结D的梯度,所以这里就不写detach了。

            【讨论】:

              猜你喜欢
              • 1970-01-01
              • 2020-01-24
              • 2018-04-29
              • 2018-11-06
              • 1970-01-01
              • 2015-06-30
              • 2020-12-20
              • 2015-05-20
              • 1970-01-01
              相关资源
              最近更新 更多