(已编辑以包含更深入的讨论、更好的代码和更多的基准测试)
总结
对于原始速度和效率,可以使用其中一种经典算法的 Cython 或 Numba 加速版本(当输入分别是 Python 序列或 NumPy 数组时)。
推荐的方法是:
-
find_kmp_cy() 用于 Python 序列(list、tuple 等)
-
find_kmp_nb() 用于 NumPy 数组
其他有效的方法是 find_rk_cy() 和 find_rk_nb(),它们的内存效率更高,但不能保证在线性时间内运行。
如果 Cython / Numba 不可用,find_kmp() 和 find_rk() 对于大多数用例来说都是一个很好的全方位解决方案,尽管在一般情况下和 Python 序列中,天真的方法,在某种形式下,特别是find_pivot(),可能会更快。对于 NumPy 数组,find_conv()(来自@Jaime answer)优于任何非加速的幼稚方法。
(完整代码如下,here和there。)
理论
这是计算机科学中的一个经典问题,被称为字符串搜索或字符串匹配问题。
朴素的方法基于两个嵌套循环,平均计算复杂度为O(n + m),但最坏的情况是O(n m)。
多年来,已经开发了许多alternative approaches,以保证更好的最坏情况性能。
在经典算法中,最适合通用序列的算法(因为它们不依赖于字母表)是:
最后一种算法依赖于rolling hash 的计算来提高效率,因此可能需要一些关于输入的额外知识才能获得最佳性能。
最终,它最适合同类数据,例如数字数组。
Python 中数值数组的一个值得注意的例子当然是 NumPy 数组。
备注
- 朴素算法非常简单,适合在 Python 中以不同程度的运行时速度实现不同的实现。
- 其他算法在可通过语言技巧优化的方面不太灵活。
- Python 中的显式循环可能是速度瓶颈,可以使用多种技巧在解释器之外执行循环。
-
Cython 特别擅长加速通用 Python 代码的显式循环。
-
Numba 特别擅长加速 NumPy 数组上的显式循环。
- 这是生成器的绝佳用例,因此所有代码都将使用这些生成器而不是常规函数。
Python 序列(list、tuple 等)
基于朴素算法
-
find_loop()、find_loop_cy() 和 find_loop_nb() 分别是纯 Python、Cython 和 Numba JITing 中的显式循环实现。请注意 Numba 版本中的 forceobj=True,这是必需的,因为我们使用的是 Python 对象输入。
def find_loop(seq, subseq):
n = len(seq)
m = len(subseq)
for i in range(n - m + 1):
found = True
for j in range(m):
if seq[i + j] != subseq[j]:
found = False
break
if found:
yield i
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True
def find_loop_cy(seq, subseq):
cdef Py_ssize_t n = len(seq)
cdef Py_ssize_t m = len(subseq)
for i in range(n - m + 1):
found = True
for j in range(m):
if seq[i + j] != subseq[j]:
found = False
break
if found:
yield i
find_loop_nb = nb.jit(find_loop, forceobj=True)
find_loop_nb.__name__ = 'find_loop_nb'
-
find_all() 在理解生成器上用 all() 替换内部循环
def find_all(seq, subseq):
n = len(seq)
m = len(subseq)
for i in range(n - m + 1):
if all(seq[i + j] == subseq[j] for j in range(m)):
yield i
-
find_slice() 切片后直接比较== 替换内循环[]
def find_slice(seq, subseq):
n = len(seq)
m = len(subseq)
for i in range(n - m + 1):
if seq[i:i + m] == subseq:
yield i
-
find_mix() 和 find_mix2() 在切片 [] 后用直接比较 == 替换内部循环,但在第一个(和最后一个)字符上包含一个或两个额外的短路,这可能更快,因为使用 int 切片比使用slice() 切片要快得多。
def find_mix(seq, subseq):
n = len(seq)
m = len(subseq)
for i in range(n - m + 1):
if seq[i] == subseq[0] and seq[i:i + m] == subseq:
yield i
def find_mix2(seq, subseq):
n = len(seq)
m = len(subseq)
for i in range(n - m + 1):
if seq[i] == subseq[0] and seq[i + m - 1] == subseq[m - 1] \
and seq[i:i + m] == subseq:
yield i
-
find_pivot() 和 find_pivot2() 使用子序列的第一项将外部循环替换为多个 .index() 调用,同时对内部循环使用切片,最终在最后一项上进行额外的短路(第一个匹配建造)。多个.index() 调用被包装在index_all() 生成器中(它本身可能很有用)。
def index_all(seq, item, start=0, stop=-1):
try:
n = len(seq)
if n > 0:
start %= n
stop %= n
i = start
while True:
i = seq.index(item, i)
if i <= stop:
yield i
i += 1
else:
return
else:
return
except ValueError:
pass
def find_pivot(seq, subseq):
n = len(seq)
m = len(subseq)
if m > n:
return
for i in index_all(seq, subseq[0], 0, n - m):
if seq[i:i + m] == subseq:
yield i
def find_pivot2(seq, subseq):
n = len(seq)
m = len(subseq)
if m > n:
return
for i in index_all(seq, subseq[0], 0, n - m):
if seq[i + m - 1] == subseq[m - 1] and seq[i:i + m] == subseq:
yield i
基于 Knuth-Morris-Pratt (KMP) 算法
-
find_kmp() 是该算法的普通 Python 实现。由于没有简单的循环或可以将切片与slice() 一起使用的地方,因此除了使用 Cython 之外,没有太多需要做的优化(Numba 将再次需要 forceobj=True,这会导致代码变慢)。李>
def find_kmp(seq, subseq):
n = len(seq)
m = len(subseq)
# : compute offsets
offsets = [0] * m
j = 1
k = 0
while j < m:
if subseq[j] == subseq[k]:
k += 1
offsets[j] = k
j += 1
else:
if k != 0:
k = offsets[k - 1]
else:
offsets[j] = 0
j += 1
# : find matches
i = j = 0
while i < n:
if seq[i] == subseq[j]:
i += 1
j += 1
if j == m:
yield i - j
j = offsets[j - 1]
elif i < n and seq[i] != subseq[j]:
if j != 0:
j = offsets[j - 1]
else:
i += 1
-
find_kmp_cy() 是算法的 Cython 实现,其中索引使用 C int 数据类型,这会导致代码更快。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True
def find_kmp_cy(seq, subseq):
cdef Py_ssize_t n = len(seq)
cdef Py_ssize_t m = len(subseq)
# : compute offsets
offsets = [0] * m
cdef Py_ssize_t j = 1
cdef Py_ssize_t k = 0
while j < m:
if subseq[j] == subseq[k]:
k += 1
offsets[j] = k
j += 1
else:
if k != 0:
k = offsets[k - 1]
else:
offsets[j] = 0
j += 1
# : find matches
cdef Py_ssize_t i = 0
j = 0
while i < n:
if seq[i] == subseq[j]:
i += 1
j += 1
if j == m:
yield i - j
j = offsets[j - 1]
elif i < n and seq[i] != subseq[j]:
if j != 0:
j = offsets[j - 1]
else:
i += 1
基于 Rabin-Karp (RK) 算法
-
find_rk() 是一个纯 Python 实现,它依赖 Python 的 hash() 来计算(和比较)哈希。这样的哈希是通过一个简单的sum() 滚动的。然后,通过减去刚刚访问的项目 seq[i - 1] 上的 hash() 的结果并将新考虑的项目 seq[i + m - 1] 上的 hash() 的结果相加,从先前的哈希中计算出翻转。
def find_rk(seq, subseq):
n = len(seq)
m = len(subseq)
if seq[:m] == subseq:
yield 0
hash_subseq = sum(hash(x) for x in subseq) # compute hash
curr_hash = sum(hash(x) for x in seq[:m]) # compute hash
for i in range(1, n - m + 1):
curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1]) # update hash
if hash_subseq == curr_hash and seq[i:i + m] == subseq:
yield i
-
find_rk_cy() 是算法的 Cython 实现,其中索引使用适当的 C 数据类型,这会产生更快的代码。请注意,hash() 会截断“基于主机位宽的返回值”。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True
def find_rk_cy(seq, subseq):
cdef Py_ssize_t n = len(seq)
cdef Py_ssize_t m = len(subseq)
if seq[:m] == subseq:
yield 0
cdef Py_ssize_t hash_subseq = sum(hash(x) for x in subseq) # compute hash
cdef Py_ssize_t curr_hash = sum(hash(x) for x in seq[:m]) # compute hash
cdef Py_ssize_t old_item, new_item
for i in range(1, n - m + 1):
old_item = hash(seq[i - 1])
new_item = hash(seq[i + m - 1])
curr_hash += new_item - old_item # update hash
if hash_subseq == curr_hash and seq[i:i + m] == subseq:
yield i
基准
上述函数在两个输入上进行评估:
def gen_input(n, k=2):
return tuple(random.randint(0, k - 1) for _ in range(n))
def gen_input_worst(n, k=-2):
result = [0] * n
result[k] = 1
return tuple(result)
subseq 具有固定大小 (32)。
由于有很多替代方案,因此已经完成了两个单独的分组,并且省略了一些变化非常小且时间几乎相同的解决方案(即find_mix2() 和find_pivot2())。
对于每组,两个输入都经过测试。
对于每个基准测试,都提供了完整图和最快方法的放大图。
天真随机
最坏情况下的幼稚
其他随机
其他最差
(完整代码可用here。)
NumPy 数组
基于朴素算法
-
find_loop()、find_loop_cy() 和 find_loop_nb() 分别是纯 Python、Cython 和 Numba JITing 中的仅显式循环实现。前两个的代码与上面相同,因此省略。 find_loop_nb() 现在享受快速 JIT 编译。内部循环已编写在一个单独的函数中,因为它可以在 find_rk_nb() 中重复使用(在 Numba 函数中调用 Numba 函数不会产生 Python 典型的函数调用惩罚)。
@nb.jit
def _is_equal_nb(seq, subseq, m, i):
for j in range(m):
if seq[i + j] != subseq[j]:
return False
return True
@nb.jit
def find_loop_nb(seq, subseq):
n = len(seq)
m = len(subseq)
for i in range(n - m + 1):
if _is_equal_nb(seq, subseq, m, i):
yield i
find_all()和上面一样,而find_slice()、find_mix()和find_mix2()和上面几乎一样,唯一的区别是seq[i:i + m] == subseq现在是np.all()的参数:np.all(seq[i:i + m] == subseq)。
find_pivot() 和 find_pivot2() 与上面的想法相同,只是现在使用 np.where() 而不是 index_all() 并且需要在 np.all() 调用中包含数组相等性。
def find_pivot(seq, subseq):
n = len(seq)
m = len(subseq)
if m > n:
return
max_i = n - m
for i in np.where(seq == subseq[0])[0]:
if i > max_i:
return
elif np.all(seq[i:i + m] == subseq):
yield i
def find_pivot2(seq, subseq):
n = len(seq)
m = len(subseq)
if m > n:
return
max_i = n - m
for i in np.where(seq == subseq[0])[0]:
if i > max_i:
return
elif seq[i + m - 1] == subseq[m - 1] \
and np.all(seq[i:i + m] == subseq):
yield i
-
find_rolling() 通过滚动窗口表示循环,并使用np.all() 检查匹配。这以创建大型临时对象为代价对所有循环进行了矢量化,同时仍然大量应用了朴素算法。 (方法来自@senderle answer)。
def rolling_window(arr, size):
shape = arr.shape[:-1] + (arr.shape[-1] - size + 1, size)
strides = arr.strides + (arr.strides[-1],)
return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)
def find_rolling(seq, subseq):
bool_indices = np.all(rolling_window(seq, len(subseq)) == subseq, axis=1)
yield from np.mgrid[0:len(bool_indices)][bool_indices]
-
find_rolling2() 是上述的一种内存效率稍高的变体,其中矢量化只是部分的,并且保留了一个显式循环(沿着预期的最短维度——subseq 的长度)。 (方法也来自@senderle answer)。
def find_rolling2(seq, subseq):
windows = rolling_window(seq, len(subseq))
hits = np.ones((len(seq) - len(subseq) + 1,), dtype=bool)
for i, x in enumerate(subseq):
hits &= np.in1d(windows[:, i], [x])
yield from hits.nonzero()[0]
基于 Knuth-Morris-Pratt (KMP) 算法
-
find_kmp() 与上面相同,而 find_kmp_nb() 是直接的 JIT 编译。
find_kmp_nb = nb.jit(find_kmp)
find_kmp_nb.__name__ = 'find_kmp_nb'
基于 Rabin-Karp (RK) 算法
@nb.jit
def sum_hash_nb(arr):
result = 0
for x in arr:
result += hash(x)
return result
@nb.jit
def find_rk_nb(seq, subseq):
n = len(seq)
m = len(subseq)
if _is_equal_nb(seq, subseq, m, 0):
yield 0
hash_subseq = sum_hash_nb(subseq) # compute hash
curr_hash = sum_hash_nb(seq[:m]) # compute hash
for i in range(1, n - m + 1):
curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1]) # update hash
if hash_subseq == curr_hash and _is_equal_nb(seq, subseq, m, i):
yield i
-
find_conv() 使用伪 Rabin-Karp 方法,其中初始候选者使用 np.dot() 乘积进行散列,并位于 seq 和 subseq 与 np.where() 之间的卷积上。该方法是伪的,因为虽然它仍然使用散列来识别可能的候选者,但它可能不被视为滚动散列(它取决于np.correlate() 的实际实现。另外,它需要创建一个大小为输入。(方法来自@Jaime answer)。
def find_conv(seq, subseq):
target = np.dot(subseq, subseq)
candidates = np.where(np.correlate(seq, subseq, mode='valid') == target)[0]
check = candidates[:, np.newaxis] + np.arange(len(subseq))
mask = np.all((np.take(seq, check) == subseq), axis=-1)
yield from candidates[mask]
基准
和以前一样,上面的函数在两个输入上进行评估:
def gen_input(n, k=2):
return np.random.randint(0, k, n)
def gen_input_worst(n, k=-2):
result = np.zeros(n, dtype=int)
result[k] = 1
return result
subseq 具有固定大小 (32)。
该图遵循与之前相同的方案,为方便起见总结如下。
由于有很多替代方案,因此已经完成了两个单独的分组,并且省略了一些变化非常小且时间几乎相同的解决方案(即find_mix2() 和find_pivot2())。
对于每组,两个输入都经过测试。
对于每个基准测试,都提供了完整图和最快方法的放大图。
天真随机
最坏情况下的幼稚
其他随机
其他最差
(完整代码可用here。)