【问题标题】:GPU and `jax` performance mysteriesGPU 和“jax”性能之谜
【发布时间】:2021-12-23 15:02:34
【问题描述】:

最近一直在玩jax,印象很深刻,但是接下来的一组实验让我很困惑:

首先,我们设置计时器实用程序:

import time
def timefunc(foo, *args):
   tic = time.perf_counter()
   tmp = foo(*args)
   toc = time.perf_counter()
   print(toc - tic)
   return tmp

现在,让我们看看当我们计算随机对称矩阵矩阵的特征值时会发生什么,因此(jnpjax.numpy,所以 eigh 在 GPU 上完成)

def jfunc(n):
  tmp = np.random.randn(n, n)
  return jnp.linalg.eigh(tmp + tmp.T)

def nfunc(n):
  tmp = np.random.randn(n, n)
  return np.linalg.eigh(tmp + tmp.T)

现在是时间安排(机器是 nVidia DGX 机器,所以 GPU 是 A100,而 CPU 是一些 AMD EPYC2 部件。

>>> e1 = timefunc(nfunc, 10)
0.0002442029945086688
>>> e2 = timefunc(jfunc, 10)
0.013523647998226807
>>> e1 = timefunc(nfunc, 100)
0.11742364699603058
>>> e2 = timefunc(jfunc, 100)
0.11005625998950563
>>> e1 = timefunc(nfunc, 1000)
0.6572738009999739
>>> e2 = timefunc(jfunc, 1000)
0.5530761769914534
>>> e1 = timefunc(nfunc, 10000)
36.22587636699609
>>> e2 = timefunc(jfunc, 10000)
8.867857075005304

您会注意到交叉点在 1000 左右。最初,我认为这是因为将内容移入/移出 GPU 的开销,但如果您定义另一个函数:

def jjfunc(n):
  key=jax.random.PRNGKey(0)
  tmp = jax.random.normal(key, [n, n])
  return jnp.linalg.eigh(tmp + tmp.T)


>>> e1=timefunc(jjfunc, 10)
0.01886096798989456
>>> e1=timefunc(jjfunc, 100)
0.2756766739912564
>>> e1=timefunc(jjfunc, 1000)
0.7205733209993923
>>> e1=timefunc(jjfunc, 10000)
6.8624101399909705

请注意,小示例实际上比将 numpy 数组移动到 GPU 并返回要慢得多。

所以,我的问题是:发生了什么事,是否有灵丹妙药?这是 jax 的实现错误吗?

【问题讨论】:

    标签: python numpy performance jax


    【解决方案1】:

    出于以下几个原因,我认为您的时间安排不能反映实际的 JAX 与 numpy 的性能:

    • JAX 的计算模型使用Asynchronous Dispatch,这意味着JAX 操作在计算完成之前返回。如该链接所述,您可以使用 block_until_ready() 方法来确保您正在计时计算而不是调度。
    • 因为像eigh 这样的操作是JIT-compiled by default,所以在给定大小第一次运行它们时会产生一次性编译成本。由于 JAX 会缓存以前的编译,因此后续运行会更快。
    • 您的计算确实被设备传输成本所挫败。直接测量最容易看出:
      def transfer(n):
        tmp = np.random.randn(n, n)
        return jnp.array(tmp).block_until_ready()
      timefunc(transfer, 10000);
      # 4.600406924000026
      
    • 您的jjfunc 结合了eigh 呼叫和jax.random.normal 呼叫。后者比 numpy 的随机数生成要慢,我相信它在小的 n 中占主导地位。
    • 与 JAX 无关,但通常使用 time.time 分析 Python 代码可能会给您带来误导性的结果。 timeit 之类的模块更适合这类事情,尤其是当您处理在几分之一秒内完成的微基准测试时。

    如果您对 JAX 与 Numpy 版本的算法的准确基准测试感兴趣,我建议您完全隔离您对基准测试感兴趣的操作(即生成数据并在基准测试之外进行任何设备传输)。阅读 JAX 中 Asynchronous Dispatch 中与基准测试相关的建议,并查看 Python's timeit Docs 以获取有关获得小代码 sn-ps 准确计时的提示(尽管我发现 %timeit magic 在使用 IPython 或 Jupyter 时更方便笔记本)。

    【讨论】:

    • 我对 JIT 的评论有点困惑。像linalg.eigh 这样的“内置”能从JIT-ing 中受益吗?即使他们这样做了,jfuncjjfunc 在同一个会话中运行,所以,正如你所说,至少其中一个(jjfunc)应该从 JIT-ing 中受益。没有?
    • linalg.eigh 默认情况下是 JIT 编译的 (source):它确实受益于这种 JITting,因为它在后续运行中执行得更快,但它不会受益于来自用户,因为它已经编译。这有意义吗?
    • 确实如此,但这意味着所有 jjfunc 运行都受益于 JIT。现在,关于您的timeit 评论:当然对于需要这么长时间的代码,我们使用哪个计时器并不重要,不是吗?
    • 是的,我的意思是要小心 JIT 的影响,否则编译/未编译导致的时间差异可能会被错误归因于其他原因。关于timeit:它做了很多事情来确保其他噪声源不会影响你的基准测试,所以无论你的运行需要多长时间,使用它都是一个好习惯。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2011-09-01
    • 2012-07-09
    • 1970-01-01
    • 2011-01-10
    • 1970-01-01
    • 2017-04-04
    相关资源
    最近更新 更多