【问题标题】:How to compute gradient of power function wrt exponent in PyTorch?如何在 PyTorch 中计算幂函数 wrt 指数的梯度?
【发布时间】:2019-11-04 17:34:21
【问题描述】:

我正在尝试计算

的梯度
out = x.sign()*torch.pow(x.abs(), alpha)

关于 alpha。

到目前为止,我尝试了以下方法:

class Power(nn.Module):
  def __init__(self, alpha=2.):
    super(Power, self).__init__()
    self.alpha = nn.Parameter(torch.tensor(alpha))

  def forward(self, x):
    return x.sign()*torch.abs(x)**self.alpha

但是这门课一直在给我nan 训练我的网络。我希望看到像grad=out*torch.log(x) 这样的东西,但无法做到。例如,此代码不返回任何内容:

alpha_rooting = Power()
x = torch.randn((1), device='cpu', dtype=torch.float)
out = (alpha_rooting(x)).sum()
out.backward()
print(out.grad)

我也试图为此使用autograd,但也没有运气。我应该如何解决这个问题?谢谢。

【问题讨论】:

    标签: pytorch autograd


    【解决方案1】:

    您编写的Power() 类按预期工作。您实际使用它的方式存在问题。渐变存储在该变量的.grad 中,而不是您在上面使用的out 变量中。您可以更改代码如下。

    alpha_rooting = Power()
    x = torch.randn((1), device='cpu', dtype=torch.float)
    out = (alpha_rooting(x)).sum()
    
    # compute gradients of all parameters with respect to out (dout/dparam)
    out.backward()
    # print gradient of alpha
    # Note that gradients are store in .grad of parameter not out variable
    print(alpha_rooting.alpha.grad)
    
    # compare if it is approximately correct to exact grads
    err = (alpha_rooting.alpha.grad - out*torch.log(x))**2 
    if (err <1e-8):
        print("Gradients are correct")
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2020-03-27
      • 2021-09-11
      • 2020-05-12
      • 1970-01-01
      • 2020-03-10
      相关资源
      最近更新 更多