【发布时间】:2020-10-11 05:41:25
【问题描述】:
我正在尝试更深入地了解 Pytorch 的 autograd 是如何工作的。我无法解释以下结果:
import torch
def fn(a):
b = torch.tensor(5,dtype=torch.float32,requires_grad=True)
return a*b
a = torch.tensor(10,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
输出是张量(5.)。但我的问题是变量 b 是在函数中创建的,因此应该在函数返回 a*b 后从内存中删除,对吗?因此,当我向后调用时,b 的值如何仍然存在以允许此计算? 据我了解,Pytorch 中的每个操作都有一个上下文变量,它跟踪“哪个”张量用于向后计算,并且每个张量中也存在版本,如果版本发生变化,那么向后应该会引发错误,对吧?
现在当我尝试运行以下代码时,
import torch
def fn(a):
b = a**2
for i in range(5):
b *= b
return b
a = torch.tensor(10,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
我收到以下错误:梯度计算所需的变量之一已被就地操作修改:[torch.FloatTensor []],它是 MulBackward0 的输出 0,是版本 5;而是预期的版本 4。提示:使用 torch.autograd.set_detect_anomaly(True) 启用异常检测以查找未能计算其梯度的操作。
但是如果我运行下面的代码,没有错误:
import torch
def fn(a):
b = a**2
for i in range(2):
b = b*b
return b
def fn2(a):
b = a**2
c = a**2
for i in range(2):
c *= b
return c
a = torch.tensor(5,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
output2 = fn2(a)
output2.backward()
print(a.grad)
这个的输出是:
张量(625000.)
张量(643750.)
因此,对于具有相当多变量的标准计算图,在同一个函数中,我能够理解计算图是如何工作的。但是,当在调用后向函数之前变量发生变化时,我在理解结果时遇到了很多麻烦。谁能解释一下?
【问题讨论】:
标签: python pytorch gradient-descent autograd