【发布时间】:2019-02-12 10:33:48
【问题描述】:
在多个.backward() 传递之间,我想将渐变设置为零。现在我必须分别为每个组件执行此操作(这里是x 和t),有没有办法为所有受影响的变量“全局”执行此操作? (我想像z.set_all_gradients_to_zero()。)
如果您使用优化器,我知道有optimizer.zero_grad(),但是否也有不使用优化器的直接方法?
import torch
x = torch.randn(3, requires_grad = True)
t = torch.randn(3, requires_grad = True)
y = x + t
z = y + y.flip(0)
z.backward(torch.tensor([1., 0., 0.]), retain_graph = True)
print(x.grad)
print(t.grad)
x.grad.data.zero_() # both gradients need to be set to zero
t.grad.data.zero_()
z.backward(torch.tensor([0., 1., 0.]), retain_graph = True)
print(x.grad)
print(t.grad)
【问题讨论】: