您可以通过使用torch.no_grad() 包裹整个循环来做到这一点:
grads = torch.autograd.grad(loss, net.parameters(), create_graph=True)
with torch.no_grad():
for param, gi in zip(net.parameters(), grads):
param -= eps*gi
或者,您可以在param 的data 属性上使用就地copy_():
grads = torch.autograd.grad(loss, net.parameters(), create_graph=True)
for param, gi in zip(net.parameters(), grads):
param.data.copy_(param.data - eps*gi)
据我测试,这两种方法更新参数的方式相同。
我还没有找到任何方法来复制grad_fn 属性。作为一种解决方法,您可以复制到 gi 而不是 param,这将使用模型的新参数覆盖 grads 的值:
grads = torch.autograd.grad(loss, net.parameters(), create_graph=True)
for param, gi in zip(net.parameters(), grads):
gi.data.copy_(param.data - eps*gi)
如果您需要保持grads 不变,只需在进入循环之前克隆它即可。