【问题标题】:jax vmap: enforce correct shapejax vmap:强制执行正确的形状
【发布时间】:2019-10-28 12:55:36
【问题描述】:

我正在使用vmap 对我的部分代码进行矢量化处理。这是一个最小的例子,在矢量化之前:

dim = 2
def sum(x):
    a = np.ones((dim,))
    return np.dot(x, a)

num_samples = 100
samples = np.ones((num_samples, dim))

sum(samples[0]) # 2

使用 vmap:

sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2

但这可能会出错,在矢量化之后:

sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1

这里发生的是samples[0] 的形状为(2,)。矢量化函数调用沿第一个轴拆分其输入参数,因此输入 2 个形状为 (1,) 的数组。由于使用a 进行广播,结果输出再次具有(2,) 的形状并堆叠到(2,2) 数组中。

这对我来说似乎很危险。代码看起来很正常,结果输出很容易被其他隐藏其损坏形状的广播规则消耗。

是否可以强制执行正确的形状?

【问题讨论】:

    标签: python vectorization jax


    【解决方案1】:

    “这对我来说似乎很危险。代码看起来很正常,生成的输出很容易被其他隐藏其损坏形状的广播规则所消耗。”

    请注意,vmap 正在做它应该在这里做的事情,即向量化第零维,而 numpy 广播正在做它应该做的事情。当然,问题是用户给出了一个形状错误的数组,因为vmap 期望在 x 的第零维中输入一个向量化的输入。相反,用户应该写

    sum(samples[0:1])
    

    保持适当的形状。

    换句话说:如果你要将 vmap 应用到一个函数上,你就不能像一开始就没有应用 vmap 一样使用这个函数。您需要考虑函数的行为变化。

    “是否可以强制执行正确的形状?”

    vmap 本身无法强制输入的形状。如果您特别担心用户给函数赋予错误的形状,您可以将其构建到原始函数中。例如,

    def sum(x): 
        if (x.shape[-1] != dim): 
            raise Exception() 
        a = np.ones((dim,)) 
        return np.dot(x, a)
    

    如果你不给它正确的形状,即使在应用vmap之后也会损坏。

    【讨论】:

      猜你喜欢
      • 2021-06-07
      • 1970-01-01
      • 1970-01-01
      • 2022-11-21
      • 2019-10-16
      • 1970-01-01
      • 1970-01-01
      • 2016-09-04
      • 2021-11-05
      相关资源
      最近更新 更多