TL;DR:首先:prange 与 range 相同,除非您将并行添加到 jit,例如 njit(parallel=True)。如果您尝试这样做,您会看到有关“不支持的缩减”的异常 - 这是因为 Numba 将 prange 的范围限制为 “纯”循环 和 “不纯循环”与 numba-支持减少,并将确保它属于这些类别的责任归于用户。
这在numbas prange (version 0.42)的文档中有明确说明:
1.10.2。显式并行循环
此代码转换过程的另一个功能是支持显式并行循环。可以使用 Numba 的 prange 而不是 range 来指定可以并行化循环。用户需要确保循环没有交叉迭代依赖,除了支持的归约。
cmets 所指的“不纯”在该文档中称为“交叉迭代依赖项”。这样的“交叉迭代依赖”是一个在循环之间变化的变量。一个简单的例子是:
def func(n):
a = 0
for i in range(n):
a += 1
return a
这里的变量a 取决于它在循环开始之前的值和循环执行了多少次迭代。这就是“交叉迭代依赖”或“不纯”循环的含义。
显式并行化这样一个循环的问题是迭代是并行执行的,但每次迭代都需要知道其他迭代在做什么。不这样做会导致错误的结果。
让我们暂时假设prange 将产生4 个worker,我们将4 作为n 传递给函数。一个完全幼稚的实现会做什么?
Worker 1 starts, gets a i = 1 from `prange`, and reads a = 0
Worker 2 starts, gets a i = 2 from `prange`, and reads a = 0
Worker 3 starts, gets a i = 3 from `prange`, and reads a = 0
Worker 1 executed the loop and sets `a = a + 1` (=> 1)
Worker 3 executed the loop and sets `a = a + 1` (=> 1)
Worker 4 starts, gets a i = 4 from `prange`, and reads a = 2
Worker 2 executed the loop and sets `a = a + 1` (=> 1)
Worker 4 executed the loop and sets `a = a + 1` (=> 3)
=> Loop ended, function return 3
不同工作人员读取、执行和写入a 的顺序可以是任意的,这只是一个示例。它也可能(偶然)产生正确的结果!这通常称为Race condition。
更复杂的prange 会如何识别存在这样的交叉迭代依赖关系?
共有三个选项:
- 根本不要并行化它。
- 实施一种机制,其中工作人员共享变量。这里的典型示例是 Locks(这可能会产生高开销)。
- 认识到这是可以并行化的归约。
鉴于我对 numba 文档的理解(再次重复):
要求用户确保循环没有交叉迭代依赖,除了支持的归约。
Numba 会:
- 如果是已知的缩减,则使用模式将其并行化
- 如果不是已知的归约,则抛出异常
很遗憾,目前尚不清楚“支持的减少”是什么。但是文档暗示它是对循环体中的前一个值进行操作的二元运算符:
如果二进制函数/运算符使用循环体中的先前值更新变量,则会自动推断减少。为+= 和*= 运算符自动推断减少的初始值。对于其他函数/运算符,归约变量应在进入prange 循环之前保存标识值。标量和任意维度的数组都支持以这种方式进行归约。
OP 中的代码使用列表作为交叉迭代依赖,并在循环体中调用list.append。就我个人而言,我不会将list.append 称为缩减,并且它没有使用二元运算符,所以我的假设是它很可能不支持。至于另一个交叉迭代依赖running:它对上一次迭代的结果使用加法(这很好),但如果超过阈值(可能不太好),也会有条件地将其重置为零。
Numba 提供了检查中间代码(LLVM 和 ASM)代码的方法:
dynamic_cumsum.inspect_types()
dynamic_cumsum.inspect_llvm()
dynamic_cumsum.inspect_asm()
但是,即使我对结果有必要的理解,可以就发出的代码的正确性做出任何陈述——通常,“证明”多线程/进程代码正常工作是非常重要的。鉴于我什至缺乏 LLVM 和 ASM 知识,甚至看不到它是否尝试并行化它,我实际上无法回答您的具体问题,它会产生什么结果。
回到代码,如前所述,如果我使用 parallel=True,它会引发异常(不支持缩减),因此我假设 numba 不会并行化示例中的任何内容:
from numba import njit, prange
@njit(parallel=True)
def dynamic_cumsum(seq, index, max_value):
cumsum = []
running = 0
for i in prange(len(seq)):
if running > max_value:
cumsum.append([index[i], running])
running = 0
running += seq[i]
cumsum.append([index[-1], running])
return cumsum
dynamic_cumsum(np.ones(100), np.arange(100), 10)
AssertionError: Invalid reduction format
During handling of the above exception, another exception occurred:
LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
Invalid reduction format
File "<>", line 7:
def dynamic_cumsum(seq, index, max_value):
<source elided>
running = 0
for i in prange(len(seq)):
^
[1] During: lowering "id=2[LoopNest(index_variable = parfor_index.192, range = (0, seq_size0.189, 1))]{56: <ir.Block at <> (10)>, 24: <ir.Block at <> (7)>, 34: <ir.Block at <> (8)>}Var(parfor_index.192, <> (7))" at <> (7)
剩下要说的是:prange 与普通的range 相比,在这种情况下没有提供任何速度优势(因为它不是并行执行的)。因此,在这种情况下,我不会“冒险”潜在问题和/或让读者感到困惑 - 鉴于 numba 文档不支持它。
from numba import njit, prange
@njit
def p_dynamic_cumsum(seq, index, max_value):
cumsum = []
running = 0
for i in prange(len(seq)):
if running > max_value:
cumsum.append([index[i], running])
running = 0
running += seq[i]
cumsum.append([index[-1], running])
return cumsum
@njit
def dynamic_cumsum(seq, index, max_value):
cumsum = []
running = 0
for i in range(len(seq)): # <-- here is the only change
if running > max_value:
cumsum.append([index[i], running])
running = 0
running += seq[i]
cumsum.append([index[-1], running])
return cumsum
只是一个支持我之前提出的“不快于”声明的快速时机:
import numpy as np
seq = np.random.randint(0, 100, 10_000_000)
index = np.arange(10_000_000)
max_ = 500
# Correctness and warm-up
assert p_dynamic_cumsum(seq, index, max_) == dynamic_cumsum(seq, index, max_)
%timeit p_dynamic_cumsum(seq, index, max_)
# 468 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit dynamic_cumsum(seq, index, max_)
# 470 ms ± 9.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)