【发布时间】:2021-10-20 06:44:46
【问题描述】:
从http://implicit-layers-tutorial.org/neural_odes/ 扩展示例我想模仿 scipy 中的曲线拟合函数,scipy.optimize.curve_fit,使用 google jax。要拟合的函数是一阶 ODE。
#Generate toy data for first order ode.
import jax.numpy as jnp
import jax
import numpy as np
#input data
u = np.zeros(100)
u[10:50] = 1
t = np.arange(len(u))
u = jnp.array(u)
#first order ODE
def f(y,t,k,tau,u):
return (k*u[t]-y)/tau
#Euler integration
def odeint_euler(f, y0, t, *args):
def step(state, t):
y_prev, t_prev = state
dt = t - t_prev
y = y_prev + dt * f(y_prev, t_prev, *args)
return (y, t), y
_, ys = jax.lax.scan(step, (y0, t[0]), t[1:])
return ys
pred = odeint_euler(f, jnp.array([0.0]),t,2.,5.,u)
pred_noise = pred.reshape(-1) + 0.05* np.random.randn(len(pred)) # this is the data to be fitted
# define loss function
def loss_function(params,u,targets):
k,tau = params
pred = odeint_euler(f, jnp.array([0.0]),t,k,tau,u)
return jnp.sum((pred-targets)**2)
def update(params, u, targets):
grads = jax.grad(loss_function)(params,u, targets)
return [w - 0.0001 * dw for w,dw in zip(params, grads)]
updated_params = jnp.array([1.0,2.0]) #initial parameters
for i in range(100):
updated_params = update(updated_params, u, pred_noise)
print(updated_params)
代码运行良好。但是,与 scipy 曲线拟合相比,这运行得非常慢。即使经过 500、1000 次迭代,解的精度也不好。
上面的代码有什么问题?知道如何使代码运行得更快并获得更准确的解决方案吗?有没有更好的方法来使用jax 进行曲线拟合?
【问题讨论】:
标签: python curve-fitting jax