【问题标题】:Explanation behind the following Pytorch results以下 Pytorch 结果背后的解释
【发布时间】:2020-10-11 05:41:25
【问题描述】:

我正在尝试更深入地了解 Pytorch 的 autograd 是如何工作的。我无法解释以下结果:

import torch
def fn(a):
 b = torch.tensor(5,dtype=torch.float32,requires_grad=True)
 return a*b 

a  = torch.tensor(10,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)

输出是张量(5.)。但我的问题是变量 b 是在函数中创建的,因此应该在函数返回 a*b 后从内存中删除,对吗?因此,当我向后调用时,b 的值如何仍然存在以允许此计算? 据我了解,Pytorch 中的每个操作都有一个上下文变量,它跟踪“哪个”张量用于向后计算,并且每个张量中也存在版本,如果版本发生变化,那么向后应该会引发错误,对吧?

现在当我尝试运行以下代码时,

import torch
def fn(a):
 b = a**2
 for i in range(5):
   b *= b
 return b 

a  = torch.tensor(10,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)

我收到以下错误:梯度计算所需的变量之一已被就地操作修改:[torch.FloatTensor []],它是 MulBackward0 的输出 0,是版本 5;而是预期的版本 4。提示:使用 torch.autograd.set_detect_anomaly(True) 启用异常检测以查找未能计算其梯度的操作。

但是如果我运行下面的代码,没有错误:

import torch
def fn(a):
  b = a**2
  for i in range(2):
    b = b*b
  return b

def fn2(a):
  b = a**2
  c = a**2
  for i in range(2):
    c *= b
  return c

a  = torch.tensor(5,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
output2 = fn2(a)
output2.backward()
print(a.grad)

这个的输出是:

张量(625000.)

张量(643750.)

因此,对于具有相当多变量的标准计算图,在同一个函数中,我能够理解计算图是如何工作的。但是,当在调用后向函数之前变量发生变化时,我在理解结果时遇到了很多麻烦。谁能解释一下?

【问题讨论】:

    标签: python pytorch gradient-descent autograd


    【解决方案1】:

    请注意b *=bb = b*b 不同。

    这可能令人困惑,但底层操作各不相同。

    b *=b 的情况下,会发生就地操作,这会扰乱渐变,因此会扰乱RuntimeError

    b = b*b 的情况下,两个张量对象相乘,结果对象被命名为b。因此,当您以这种方式运行时,没有RuntimeError

    这是一个关于底层python操作的SO问题:The difference between x += y and x = x + y

    现在fn 在第一种情况下和fn2 在第二种情况下有什么区别?操作c*=b 不会破坏从cb 的图链接。操作c*=c 将使得不可能有通过操作连接两个张量的图。

    好吧,我不能使用张量来展示这一点,因为它们会引发 RuntimeError。所以我会尝试使用 python 列表。

    >>> x = [1,2]
    >>> y = [3]
    >>> id(x), id(y)
    (140192646516680, 140192646927112)
    >>>
    >>> x += y
    >>> x, y
    ([1, 2, 3], [3])
    >>> id(x), id(y)
    (140192646516680, 140192646927112)
    

    请注意,没有创建新对象。所以不可能从output 追踪到初始变量。我们无法区分object_140192646516680 是输出还是输入。那么如何用它创建一个图表..

    考虑以下替代情况:

    >>> a = [1,2]
    >>> b = [3]
    >>>
    >>> id(a), id(b)
    (140192666168008, 140192666168264)
    >>>
    >>> a = a + b
    >>> a, b
    ([1, 2, 3], [3])
    >>> id(a), id(b)
    (140192666168328, 140192666168264)
    >>>
    

    请注意,新列表 a 实际上是带有 id 140192666168328 的新对象。在这里,我们可以追踪到object_140192666168328 来自其他两个对象object_140192666168008object_140192666168264 之间的addition operation。因此可以动态创建图形,并且可以将梯度从output 传播回之前的层。

    【讨论】:

    • 嗨。感谢您的就地澄清。能否请您详细说明最后一点?
    • 扩展了答案以进行更多描述。希望这会有所帮助。
    • c*=c 与 c*=b 有何不同?似乎在带有列表的示例中,当您应用 x+=y 时,对象的 id 没有改变?
    • 不同之处在于c*=c 会蚕食自己的历史,因为图边的起点和终点是相同的对象,因此图会崩溃。 c*=b 允许从一个不同的对象 c 到另一个对象 b 的图形链接。
    猜你喜欢
    • 1970-01-01
    • 2020-11-08
    • 2017-06-13
    • 2017-12-08
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多