【问题标题】:Update net.parameters() without .data在没有 .data 的情况下更新 net.parameters()
【发布时间】:2021-04-05 07:36:30
【问题描述】:

有没有办法用其他一些带有梯度的张量来更新网络参数?

我想做如下的事情:

grads = torch.autograd.grad(loss, net.parameters(), 
                                    create_graph=True) 

for param gi in zip(net.parameters(), grads): 
       param -= eps * gi

我希望每个参数都携带 gi 的 grad_fn。

【问题讨论】:

    标签: pytorch autograd


    【解决方案1】:

    您可以通过使用torch.no_grad() 包裹整个循环来做到这一点:

    grads = torch.autograd.grad(loss, net.parameters(), create_graph=True) 
    with torch.no_grad():
        for param, gi in zip(net.parameters(), grads):
            param -= eps*gi
    

    或者,您可以在paramdata 属性上使用就地copy_()

    grads = torch.autograd.grad(loss, net.parameters(), create_graph=True) 
    for param, gi in zip(net.parameters(), grads): 
        param.data.copy_(param.data - eps*gi)
    

    据我测试,这两种方法更新参数的方式相同。


    我还没有找到任何方法来复制grad_fn 属性。作为一种解决方法,您可以复制到 gi 而不是 param,这将使用模型的新参数覆盖 grads 的值:

    grads = torch.autograd.grad(loss, net.parameters(), create_graph=True) 
    for param, gi in zip(net.parameters(), grads): 
        gi.data.copy_(param.data - eps*gi)
    

    如果您需要保持grads 不变,只需在进入循环之前克隆它即可。

    【讨论】:

    • 谢谢。但更新后的参数不携带 grad_fn 的 grads。那是我的问题。他们的 grad_fn 是 requires_grad = True,而 grads 则不同。
    猜你喜欢
    • 1970-01-01
    • 2011-05-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2017-05-23
    • 2013-04-15
    相关资源
    最近更新 更多