【问题标题】:how to do curve fitting using google jax?如何使用 google jax 进行曲线拟合?
【发布时间】:2021-10-20 06:44:46
【问题描述】:

http://implicit-layers-tutorial.org/neural_odes/ 扩展示例我想模仿 scipy 中的曲线拟合函数,scipy.optimize.curve_fit,使用 google jax。要拟合的函数是一阶 ODE。

#Generate toy data for first order ode.

import jax.numpy as jnp
import jax
import numpy as np


#input  data 
u = np.zeros(100)  
u[10:50] = 1
t = np.arange(len(u))
u = jnp.array(u)

#first order ODE
def f(y,t,k,tau,u):
 
  return (k*u[t]-y)/tau
  
#Euler integration
def odeint_euler(f, y0, t, *args):
  def step(state, t):
    y_prev, t_prev = state
    dt = t - t_prev
    y = y_prev + dt * f(y_prev, t_prev, *args)
    return (y, t), y
  _, ys = jax.lax.scan(step, (y0, t[0]), t[1:])
  return ys

pred = odeint_euler(f, jnp.array([0.0]),t,2.,5.,u) 
pred_noise = pred.reshape(-1) +  0.05* np.random.randn(len(pred)) # this is the  data to be fitted

# define loss function 
def loss_function(params,u,targets):
  k,tau = params
  
  pred = odeint_euler(f, jnp.array([0.0]),t,k,tau,u)
  return jnp.sum((pred-targets)**2)      


def update(params, u, targets):
  grads = jax.grad(loss_function)(params,u, targets)
  return [w - 0.0001 * dw for w,dw  in zip(params, grads)] 


updated_params = jnp.array([1.0,2.0]) #initial parameters
for i in range(100):
  updated_params = update(updated_params, u, pred_noise)
print(updated_params)

代码运行良好。但是,与 scipy 曲线拟合相比,这运行得非常慢。即使经过 500、1000 次迭代,解的精度也不好。 上面的代码有什么问题?知道如何使代码运行得更快并获得更准确的解决方案吗?有没有更好的方法来使用jax 进行曲线拟合?

【问题讨论】:

    标签: python curve-fitting jax


    【解决方案1】:

    我发现您的方法存在两个总体问题:

    1. 您的代码运行缓慢的原因是您在 Python 中执行循环,这会在每个循环中产生 JAX 的调度开销。我建议使用 JAX 的内置工具来最小化损失函数;例如:
    from jax.scipy.optimize import minimize
    result = minimize(
        loss_function, x0=jnp.array([1.0,2.0]),
        method='BFGS', args=(u, pred_noise))
    
    1. 您的准确性没有达到 scipy 的原因可能是因为 JAX 默认使用 32 位计算(请参阅 Double (64 bit) Precision)。要在 64 位中运行您的代码,您可以在任何其他导入之前运行此块:
    from jax import config
    config.update('jax_enable_x64', True)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2021-11-14
      • 1970-01-01
      • 2015-07-04
      • 1970-01-01
      • 2021-10-01
      • 2014-03-29
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多