【问题标题】:Accelerating nested for-loops in JAX加速 JAX 中的嵌套 for 循环
【发布时间】:2022-01-03 15:56:00
【问题描述】:

我想使用 JAX 的 jit 方法加速下面示例中的嵌套 for 循环。 但是编译时间很长,编译后的运行时间比不使用jit的版本还要慢。

我是否正确使用jit?我应该在这里使用 JAX 中的其他功能吗?

import time
import jax.numpy as jnp
from jax import jit
from jax import random

key = random.PRNGKey(seed=0)

width = 32
height = 64

w = random.normal(key=key, shape=(height, width))

def forward():
    a = jnp.zeros(shape=(height, width + 1))

    for i in range(height):
        a = a.at[i, 0].add(1.0)

    for j in range(width):
        for i in range(1, height-1):
            z = a[i-1, j] * w[i-1, j] \
                + a[i, j] * w[i, j] \
                + a[i+1, j] * w[i+1, j]
            a = a.at[i, j+1].set(z)

t0 = time.time()
forward()
print(time.time()-t0)

feedforward_jit = jit(forward)

t0 = time.time()
feedforward_jit()
print(time.time()-t0)

【问题讨论】:

  • @jakevdp 我已经偶然发现了这个链接。但是,我看不到如何将页面上的信息转移到我在上面示例中加速代码的问题。如果你能给我一个提示,我会很高兴。
  • 看起来这可以写成某种convolution,这可能会有所帮助。避免变异操作也可能会。

标签: python jit jax


【解决方案1】:

对您的问题的简短回答是:要优化循环,您应该尽一切可能从程序中删除循环。

JAX(如 NumPy)是一种建立在数组操作基础上的语言,任何时候你在数组维度上进行循环,JAX(如 NumPy)都会比你想要的慢。在 JIT 编译期间尤其如此:JAX 将在将操作发送到 XLA 之前将循环展平,并且 XLA 编译时间大致与发送给它的操作数的平方成比例,因此嵌套循环是快速创建 非常编译速度很慢。

那么如何避免这些循环呢?首先,让我们重新定义您的函数,使其接受输入并返回输出(鉴于 JAX 的死代码消除和异步调度,我认为您的初始基准测试并没有告诉您您认为它们是什么;有关一些提示,请参阅 Benchmarking JAX code):

def forward(w):
  height, width = w.shape
  a = jnp.zeros(shape=(height, width + 1))

  for i in range(height):
    a = a.at[i, 0].add(1.0)

  for j in range(width):
    for i in range(1, height-1):
      z = (a[i-1, j] * w[i-1, j]
           + a[i, j] * w[i, j]
           + a[i+1, j] * w[i+1, j])
      a = a.at[i, j+1].set(z)
  return a

第一个循环是可以用单行矢量化更新替换的情况:a = a.at[:, 0].set(1)。查看下一个块的内部循环,代码似乎沿每一列进行卷积。让我们使用jnp.convolve 更有效地做到这一点。使用这两个优化会导致:

def forward2(w):
  height, width = w.shape
  a = jnp.zeros((height, width + 1)).at[:, 0].set(1)
  kernel = jnp.ones(3)
  for j in range(width):
    conv = jnp.convolve(a[:, j] * w[:, j], kernel, mode='valid')
    a = a.at[1:-1, j + 1].set(conv)
  return a

接下来让我们看一下宽度上的循环。这里比较棘手,因为每次迭代都取决于最后一次的结果。我们可以表达的一种方式是使用lax.scan,它是JAX 的内置control flow operators 之一。你可以这样做:

def forward3(w):
  def body(carry, w):
    conv = jnp.convolve(carry * w, kernel, mode='valid')
    out = jnp.zeros_like(w).at[1:-1].set(conv)
    return out, out
  init = jnp.ones(w.shape[0])
  kernel = jnp.ones(3)
  return jnp.vstack([
      init, lax.scan(body, jnp.ones(w.shape[0]), w.T)[1]]).T

我们可以快速确认这三种方法给出相同的输出:

width = 32
height = 64
w = random.normal(key=key, shape=(height, width))

result1 = forward(w)
result2 = forward2(w)
result3 = forward3(w)

assert jnp.allclose(result1, result2)
assert jnp.allclose(result2, result3)

使用 IPython 的 %time 魔法,我们可以大致了解每种方法的计算时间,这里是在 CPU 后端上(注意使用 block_until_ready() 来解释 JAX 的 Asynchronous dispatch):

%time forward(w).block_until_ready()
# CPU times: user 23 s, sys: 248 ms, total: 23.3 s
# Wall time: 22.9 s

%time forward2(w).block_until_ready()
# CPU times: user 117 ms, sys: 866 µs, total: 118 ms
# Wall time: 118 ms

%time forward3(w).block_until_ready()
# CPU times: user 93.2 ms, sys: 2.96 ms, total: 96.1 ms
# Wall time: 94 ms

您可以在https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow 阅读有关 JAX 和控制流的更多信息。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-04-27
    • 1970-01-01
    • 1970-01-01
    • 2012-05-08
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多