【发布时间】:2021-12-23 15:02:34
【问题描述】:
最近一直在玩jax,印象很深刻,但是接下来的一组实验让我很困惑:
首先,我们设置计时器实用程序:
import time
def timefunc(foo, *args):
tic = time.perf_counter()
tmp = foo(*args)
toc = time.perf_counter()
print(toc - tic)
return tmp
现在,让我们看看当我们计算随机对称矩阵矩阵的特征值时会发生什么,因此(jnp 是 jax.numpy,所以 eigh 在 GPU 上完成)
def jfunc(n):
tmp = np.random.randn(n, n)
return jnp.linalg.eigh(tmp + tmp.T)
def nfunc(n):
tmp = np.random.randn(n, n)
return np.linalg.eigh(tmp + tmp.T)
现在是时间安排(机器是 nVidia DGX 机器,所以 GPU 是 A100,而 CPU 是一些 AMD EPYC2 部件。
>>> e1 = timefunc(nfunc, 10)
0.0002442029945086688
>>> e2 = timefunc(jfunc, 10)
0.013523647998226807
>>> e1 = timefunc(nfunc, 100)
0.11742364699603058
>>> e2 = timefunc(jfunc, 100)
0.11005625998950563
>>> e1 = timefunc(nfunc, 1000)
0.6572738009999739
>>> e2 = timefunc(jfunc, 1000)
0.5530761769914534
>>> e1 = timefunc(nfunc, 10000)
36.22587636699609
>>> e2 = timefunc(jfunc, 10000)
8.867857075005304
您会注意到交叉点在 1000 左右。最初,我认为这是因为将内容移入/移出 GPU 的开销,但如果您定义另一个函数:
def jjfunc(n):
key=jax.random.PRNGKey(0)
tmp = jax.random.normal(key, [n, n])
return jnp.linalg.eigh(tmp + tmp.T)
>>> e1=timefunc(jjfunc, 10)
0.01886096798989456
>>> e1=timefunc(jjfunc, 100)
0.2756766739912564
>>> e1=timefunc(jjfunc, 1000)
0.7205733209993923
>>> e1=timefunc(jjfunc, 10000)
6.8624101399909705
请注意,小示例实际上比将 numpy 数组移动到 GPU 并返回要慢得多。
所以,我的问题是:发生了什么事,是否有灵丹妙药?这是 jax 的实现错误吗?
【问题讨论】:
标签: python numpy performance jax