【问题标题】:JAX/JIT vs Std Numpy performance: where I am wrong?JAX/JIT vs Std Numpy 性能:我错在哪里?
【发布时间】:2021-11-06 19:23:12
【问题描述】:

这是一个简单的 Simpson 集成代码练习,我已经编写了该代码以接受几个函数来集成一组边界

import numpy as np
def simps(f, a, b, N):
    #N should be even
    dx = (b - a) / N
    x = np.linspace(a, b, N + 1)
    y = f(x)
    w = np.ones_like(y)
    w[2:-1:2] = 2.
    w[1::2]   = 4.
    S = dx / 3 * np.einsum("i...,i...",w,y)
    return S

def funcN(x):
    return np.stack([x**(i/10) * np.exp(-x) for i in range(200)],axis=1)

a = np.arange(0,10,0.1)
b = a+0.05

我在一个 CPU 设备上,然后我得到一个 200 x 100 的数字数组,对应于 int(f_i, a_j,b_j) i:0-199 和 j:0-99

%timeit simps(funcN,a,b, 512)

每个循环 1.13 秒 ± 27.4 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)

现在考虑以下 JAX/JIT 版本

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial
from jax.config import config
config.update("jax_enable_x64", True)   #numpy by default is in double precision

@partial(jit, static_argnums=(0,3))
def jax_simps(f, a,b, N):
    dx = (b - a) / N
    x = jnp.linspace(a, b, N + 1)
    y = f(x)
    w = jnp.ones_like(y)
    w = w.at[2:-1:2].set(2.)
    w = w.at[1::2].set(4.)
    S = dx / 3. * jnp.einsum('i...,i...',w,y)
    return S

@jit
def jax_funcN(x):
    return jnp.stack([x**(i/10) * jnp.exp(-x) for i in range(200)],axis=1)

ja = jnp.arange(0,10,0.1)
jb = ja+0.05

#warm up
jax_simps(jax_funcN,ja,jb, 512).block_until_ready() 

%timeit jax_simps(jax_funcN,ja,jb, 512).block_until_ready() 

我已经验证了这两个代码(纯 Numpy 和 JAX/JIT)给出了相同的结果 因为最大相对误差约为 8. 10^-16.

现在,我得到了以下时间 每个循环 933 毫秒 ± 51.4 毫秒(平均值 ± 标准偏差,7 次运行,每个循环 1 个)

非常接近纯 Numpy。我是否偶然制作了一个非常有效的纯 Numpy 代码???还是我以错误的方式编写了 JAX/JIT 代码?

(注意,使用 Google Collab K80 GPU 时,每个循环的 JAX/JIT 时间下降到 7.19 毫秒,将纯 Numpy 保持在 1 秒/循环的水平)

【问题讨论】:

    标签: python numpy jit jax


    【解决方案1】:

    从您的数据来看,JAX JIT 似乎比 NumPy 在 CPU 上的速度提高了 20%。对于 CPU 执行,NumPy 已经非常优化:撇开 autodiff 之类的东西,对于类似 NumPy 的操作的短序列 JAX 在 CPU 上的主要优势是 XLA 能够融合操作以避免为中间结果分配临时数组,并且对于这个相对较短操作顺序,看起来只购买了大约 20% 的改进。

    现在,JAX 具有其他优势,包括 autodiff、批处理和(如您所提到的)在不更改代码的情况下降低到加速器的能力。但是对于在 CPU 上执行一小段向量化操作,你通常不能比单独使用 NumPy 做得更好。

    附带说明:通过将stack 替换为广播操作,您可以将 NumPy 和 JAX 版本加速 40-50%;例如:

    def funcN(x):
      x = x[:, None, :]
      i = np.arange(200)[:, None]
      return x**(i/10) * np.exp(-x)
    

    【讨论】: