对您的问题的简短回答是:要优化循环,您应该尽一切可能从程序中删除循环。
JAX(如 NumPy)是一种建立在数组操作基础上的语言,任何时候你在数组维度上进行循环,JAX(如 NumPy)都会比你想要的慢。在 JIT 编译期间尤其如此:JAX 将在将操作发送到 XLA 之前将循环展平,并且 XLA 编译时间大致与发送给它的操作数的平方成比例,因此嵌套循环是快速创建 非常编译速度很慢。
那么如何避免这些循环呢?首先,让我们重新定义您的函数,使其接受输入并返回输出(鉴于 JAX 的死代码消除和异步调度,我认为您的初始基准测试并没有告诉您您认为它们是什么;有关一些提示,请参阅 Benchmarking JAX code):
def forward(w):
height, width = w.shape
a = jnp.zeros(shape=(height, width + 1))
for i in range(height):
a = a.at[i, 0].add(1.0)
for j in range(width):
for i in range(1, height-1):
z = (a[i-1, j] * w[i-1, j]
+ a[i, j] * w[i, j]
+ a[i+1, j] * w[i+1, j])
a = a.at[i, j+1].set(z)
return a
第一个循环是可以用单行矢量化更新替换的情况:a = a.at[:, 0].set(1)。查看下一个块的内部循环,代码似乎沿每一列进行卷积。让我们使用jnp.convolve 更有效地做到这一点。使用这两个优化会导致:
def forward2(w):
height, width = w.shape
a = jnp.zeros((height, width + 1)).at[:, 0].set(1)
kernel = jnp.ones(3)
for j in range(width):
conv = jnp.convolve(a[:, j] * w[:, j], kernel, mode='valid')
a = a.at[1:-1, j + 1].set(conv)
return a
接下来让我们看一下宽度上的循环。这里比较棘手,因为每次迭代都取决于最后一次的结果。我们可以表达的一种方式是使用lax.scan,它是JAX 的内置control flow operators 之一。你可以这样做:
def forward3(w):
def body(carry, w):
conv = jnp.convolve(carry * w, kernel, mode='valid')
out = jnp.zeros_like(w).at[1:-1].set(conv)
return out, out
init = jnp.ones(w.shape[0])
kernel = jnp.ones(3)
return jnp.vstack([
init, lax.scan(body, jnp.ones(w.shape[0]), w.T)[1]]).T
我们可以快速确认这三种方法给出相同的输出:
width = 32
height = 64
w = random.normal(key=key, shape=(height, width))
result1 = forward(w)
result2 = forward2(w)
result3 = forward3(w)
assert jnp.allclose(result1, result2)
assert jnp.allclose(result2, result3)
使用 IPython 的 %time 魔法,我们可以大致了解每种方法的计算时间,这里是在 CPU 后端上(注意使用 block_until_ready() 来解释 JAX 的 Asynchronous dispatch):
%time forward(w).block_until_ready()
# CPU times: user 23 s, sys: 248 ms, total: 23.3 s
# Wall time: 22.9 s
%time forward2(w).block_until_ready()
# CPU times: user 117 ms, sys: 866 µs, total: 118 ms
# Wall time: 118 ms
%time forward3(w).block_until_ready()
# CPU times: user 93.2 ms, sys: 2.96 ms, total: 96.1 ms
# Wall time: 94 ms
您可以在https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow 阅读有关 JAX 和控制流的更多信息。