【问题标题】:How to Use AutoGrad Packages?如何使用 AutoGrad 包?
【发布时间】:2018-02-02 13:07:57
【问题描述】:

我正在尝试做一件简单的事情:使用 autograd 获取梯度并进行梯度下降:

import tangent

def model(x):
    return a*x + b

def loss(x,y):
    return (y-model(x))**2.0

在获得输入-输出对的损失后,我想获得梯度损失:

    l = loss(1,2)
    # grad_a = gradient of loss wrt a?
    a = a - grad_a
    b = b - grad_b

但是库教程没有展示如何获取关于 a 或 b 的梯度,即参数 so,既不是 autograd 也不是 tangent。

【问题讨论】:

  • 他们don't show是什么意思?
  • @sascha 是的,我在切线之前先尝试过。他们仅展示了一个带有 tanh 的示例; 1.不是函数的组合,然后2.它们的函数没有任何参数,即。它只是 x,所以没有偏导数。

标签: python machine-learning deep-learning gradient-descent autograd


【解决方案1】:

你可以用 grad 函数的第二个参数来指定:

def f(x,y):
    return x*x + x*y

f_x = grad(f,0) # derivative with respect to first argument
f_y = grad(f,1) # derivative with respect to second argument

print("f(2,3)   = ", f(2.0,3.0))
print("f_x(2,3) = ", f_x(2.0,3.0)) 
print("f_y(2,3) = ", f_y(2.0,3.0))

在您的情况下,“a”和“b”应该是损失函数的输入,损失函数将它们传递给模型以计算导数。

我刚刚回答了一个类似的问题: Partial Derivative using Autograd

【讨论】:

    【解决方案2】:

    这可能会有所帮助:

    import autograd.numpy as np
    from autograd import grad
    def tanh(x):
      y=np.exp(-x)
      return (1.0-y)/(1.0+y)
    
    grad_tanh = grad(tanh)
    
    print(grad_tanh(1.0))
    
    e=0.00001
    g=(tanh(1+e)-tanh(1))/e
    print(g)
    

    输出:

    0.39322386648296376
    0.39322295790622513
    

    您可以创建以下内容:

    import autograd.numpy as np
    from autograd import grad  # grad(f) returns f'
    
    def f(x): # tanh
      y = np.exp(-x)
      return  (1.0 - y) / ( 1.0 + y)
    
    D_f   = grad(f) # Obtain gradient function
    D2_f = grad(D_f)# 2nd derivative
    D3_f = grad(D2_f)# 3rd derivative
    D4_f = grad(D3_f)# etc.
    D5_f = grad(D4_f)
    D6_f = grad(D5_f)
    
    import  matplotlib.pyplot  as plt
    plt.subplots(figsize = (9,6), dpi=153 )
    x = np.linspace(-7, 7, 100)
    plt.plot(x, list(map(f, x)),
             x, list(map(D_f , x)),
             x, list(map(D2_f , x)),
             x, list(map(D3_f , x)),
             x, list(map(D4_f , x)),
             x, list(map(D5_f , x)),
             x, list(map(D6_f , x)))
    plt.show()
    

    输出:

    【讨论】:

      猜你喜欢
      • 2019-07-02
      • 1970-01-01
      • 2019-08-31
      • 2020-07-11
      • 1970-01-01
      • 2017-10-27
      • 1970-01-01
      • 2021-07-23
      • 1970-01-01
      相关资源
      最近更新 更多