【问题标题】:vmap ops.index_update in JaxJax 中的 vmap ops.index_update
【发布时间】:2020-11-03 01:55:11
【问题描述】:

我在下面有以下代码,它使用了一个简单的 for 循环。我只是想知道是否有办法 vmap 它?这是原始代码:

import numpy as np 
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax

data = np.random.rand(192,334)

a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b, a, [1] + [0]*99) 
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)

@jax.jit
def filter_jax(y):
    for ind in range(0, len(y)):
      y = jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])
    return y

jnpData = jnp.asarray(data)

%timeit filter_jax(jnpData).block_until_ready()

这是我使用 vmap 的尝试:

def paraUpdate(y, ind):
    return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])

@jax.jit
def filter_jax2(y):
  ranger = range(0, len(y))
  return jax.vmap(paraUpdate, y)(ranger)

但我收到以下错误:

TypeError: vmap in_axes 必须是 int、None 或(嵌套)容器 将这些类型作为叶子,但得到了 使用 跟踪

我有点困惑,因为范围是 int 类型,所以我不太确定发生了什么。

最后,我试图让这个小部分尽可能地优化,以获得最短的时间。

【问题讨论】:

    标签: python performance numpy optimization jax


    【解决方案1】:

    jax.vmap 可以表达单个操作独立应用于输入的多个轴的功能。您的功能有点不同:您将单个操作迭代应用于单个输入。

    幸运的是,JAX 提供了lax.scan 可以处理这种情况。实现看起来像这样:

    from jax import lax
    
    def paraUpdate(y, ind):
        return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19]), ind
    
    @jax.jit
    def filter_jax2(y):
      ranger = jnp.arange(len(y))
      return lax.scan(paraUpdate, y, ranger)[0]
    
    print(np.allclose(filter_jax(jnpData), filter_jax2(jnpData)))
    # True
    
    %timeit filter_jax(jnpData).block_until_ready()
    # 10 loops, best of 3: 28.6 ms per loop
    
    %timeit filter_jax2(jnpData).block_until_ready()
    # 1000 loops, best of 3: 519 µs per loop
    

    如果您更改算法以便将操作应用于数组中的 每个 列而不是前 N 列,则可以使用 @987654327 表示@像这样:

    @jax.jit
    def filter_jax3(y):
      f = lambda col: jscp.convolve(impulse_20, col)[:-19]
      return jax.vmap(f, in_axes=1, out_axes=1)(y)
    

    【讨论】:

    • 这解决了我遇到的问题,谢谢!不过,我确实有另一个问题,我注意到 JAX 等效项不再匹配此代码的原始纯 python 版本,这可能是什么原因造成的? def filter_cpu(y): y[:,range(0, len(y))] = signal.convolve(impulse_20[np.newaxis, :], y[:,range(0, len(y))])[:, :-19] return y
    • JAX 默认为 32 位计算,除非您设置 X64 标志; scipy 默认为 64 位计算。这可能就是区别。另请注意,如果您想直接使用 jax,可以使用 2D 卷积:jax.readthedocs.io/en/latest/_autosummary/…
    猜你喜欢
    • 1970-01-01
    • 2021-06-07
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2021-11-05
    • 2021-08-11
    • 1970-01-01
    • 2022-11-21
    相关资源
    最近更新 更多