请仔细阅读backward() 上的文档以更好地理解它。
默认情况下,pytorch 期望为网络的 last 输出调用 backward() - 损失函数。损失函数总是输出一个标量,因此 scalar 损失与所有其他变量/参数的梯度是明确定义的(使用链式法则)。
因此,默认情况下,backward() 在标量张量上调用并且不需要任何参数。
例如:
a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
for j in range(3):
out = a[i,j] * a[i,j]
out.backward()
print(a.grad)
产量
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
如预期:d(a^2)/da = 2a。
但是,当您在 2×3 out 张量(不再是标量函数)上调用 backward 时,您期望 a.grad 是什么?你实际上需要一个 2×3×2×3 的输出:d out[i,j] / d a[k,l](!)
Pytorch 不支持这种非标量函数导数。相反,pytorch 假设 out 只是一个中间张量,并且在“上游”某处有一个标量损失函数,通过链式规则提供 d loss/ d out[i,j]。这个“上游”渐变的大小是 2×3,在这种情况下,这实际上是您提供的 backward 参数:out.backward(g) where g_ij = d loss/ d out_ij。
然后通过链式法则d loss / d a[i,j] = (d loss/d out[i,j]) * (d out[i,j] / d a[i,j])计算梯度
因为您提供了 a 作为“上游”渐变,所以您得到了
a.grad[i,j] = 2 * a[i,j] * a[i,j]
如果您要提供“上游”渐变为全1
out.backward(torch.ones(2,3))
print(a.grad)
产量
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
正如预期的那样。
这一切都在链式法则中。