【问题标题】:Understanding pytorch graph generation了解 pytorch 图生成
【发布时间】:2021-10-03 10:56:05
【问题描述】:

如果我运行代码:

import torch

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b

loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

pytorch 向我吐出错误“尝试再次向后浏览图表”。我的理解是再次调用损失计算线实际上并没有改变计算图,这就是我得到这个错误的原因。但是,当我调用代码时:

import torch

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b

loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward()

它工作正常(没有错误),我不明白为什么会这样,无论哪种情况,我都没有对计算图进行任何更改?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    这是个好问题。在我看来,这对于充分掌握 PyTorch 的这一特性尤为重要。在处理复杂的设置时,这一点至关重要,无论是涉及多次反向传递还是部分反向传递。

    在这两个示例中,您的计算图是:

    y ---------------------------->|
    b ----------->|                |
    w ------->|                    |
    x --> x @ w + b = z --> BCE(z, y) = loss
    

    然而,我们所说的“计算图”只是该结果计算中存在的依赖关系的表示。该结果与导致最终计算的张量相关联的方式,图的中间结果。当您计算 loss 时,loss 和所有其他张量之间仍然存在一个链接,这是计算反向传递所必需的。

    第一个场景

    在您的第一个示例中,您计算​​loss,它本身会创建一个“计算图”。注意出现在loss 变量上的grad_fn 属性。这是用于导航回图表的回调函数。在您的情况下,F.binary_cross_entropy_with_logits 将输出grad_fn=<BinaryCrossEntropyWithLogitsBackward>。话虽如此,您通过调用backward() 成功计算了反向传播,这样做使用graph_fn 的函数并更新参数的grad 属性向上传播图形。然后,您使用相同的z 定义loss,即与上一张图相关联的那个。您实际上是从上面的计算图转到以下计算图:

    y ---------------------------->|
    b ----------->|                |
    w ------->|                    |
    x --> x @ w + b = z --> BCE(z, y) = loss
                       \--> BCE(z, y) = loss # 2nd definition of loss
    

    loss 的第二个定义覆盖了loss 的先前值,是的。但是,它不会影响仍然存在的图表的第一部分:正如我所解释的,z 仍然与初始张量 xwb 相关联。

    默认情况下,在向后传递期间,激活被释放。这意味着您将无法执行第二遍。总结您的第一个示例,第二个loss.backward() 将通过loss(新的)grad_fn,然后到达初始z,其激活已被释放。这会导致您遇到的错误:

    尝试第二次向后通过图形

    第二个场景

    在第二个示例中,您通过重新计算叶张量 x 中的 z 来重新定义整个网络,因此 loss 具有中间输出 z 和叶张量 y

    从概念上讲,计算图的状态是:

    y ---------------------------->|
    b ----------->|                |
    w ------->|                    |
    x --> x @ w + b = z --> BCE(z, y) = loss
      \-> x @ w + b = z --> BCE(z, y) = loss # 2nd definition of loss
    

    这意味着通过第一次调用loss.backward 来对初始图进行反向传递。然后,在重新定义 zloss 之后,您最终会创建一个新图:上图的第二个分支。由于您不在同一个图表上,因此第二次反向传递最终会起作用。

    【讨论】:

    • 感谢您的详细解答!
    猜你喜欢
    • 2020-05-12
    • 2019-09-17
    • 2013-02-26
    • 2018-04-26
    • 1970-01-01
    • 2017-12-25
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多