【发布时间】:2019-10-28 12:55:36
【问题描述】:
我正在使用vmap 对我的部分代码进行矢量化处理。这是一个最小的例子,在矢量化之前:
dim = 2
def sum(x):
a = np.ones((dim,))
return np.dot(x, a)
num_samples = 100
samples = np.ones((num_samples, dim))
sum(samples[0]) # 2
使用 vmap:
sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2
但这可能会出错,在矢量化之后:
sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1
这里发生的是samples[0] 的形状为(2,)。矢量化函数调用沿第一个轴拆分其输入参数,因此输入 2 个形状为 (1,) 的数组。由于使用a 进行广播,结果输出再次具有(2,) 的形状并堆叠到(2,2) 数组中。
这对我来说似乎很危险。代码看起来很正常,结果输出很容易被其他隐藏其损坏形状的广播规则消耗。
是否可以强制执行正确的形状?
【问题讨论】:
标签: python vectorization jax