【问题标题】:Python/NumPy first occurrence of subarrayPython/NumPy 第一次出现子数组
【发布时间】:2023-03-12 13:45:01
【问题描述】:

在 Python 或 NumPy 中,找出第一次出现的子数组的最佳方法是什么?

例如,我有

a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]

找出 b 在 a 中的位置的最快方法(运行时)是什么?我知道对于字符串来说这非常容易,但是对于列表或 numpy ndarray 呢?

非常感谢!

[已编辑] 我更喜欢 numpy 解决方案,因为根据我的经验,numpy 向量化比 Python 列表理解要快得多。同时,大数组很大,所以我不想将其转换为字符串;那将是(太)长。

【问题讨论】:

  • 您可以将列表转换为字符串进行比较吗? x=''.join(str(x) for x in a) 然后对结果字符串使用 find 方法?还是他们必须保留列表?

标签: python numpy arrays


【解决方案1】:

我假设您正在寻找特定于 numpy 的解决方案,而不是简单的列表理解或 for 循环。一种直接的方法是使用rolling window 技术来搜索适当大小的窗口。

这种方法很简单,可以正常工作,并且比任何纯 Python 解决方案都快得多。对于许多用例来说应该足够了。然而,由于多种原因,这并不是最有效的方法。对于更复杂但在预期情况下渐近最优的方法,请参阅norok2's answer 中基于numbarolling hash 实现。

这是 rolling_window 函数:

>>> def rolling_window(a, size):
...     shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
...     strides = a.strides + (a. strides[-1],)
...     return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
... 

然后你可以做类似的事情

>>> a = numpy.arange(10)
>>> numpy.random.shuffle(a)
>>> a
array([7, 3, 6, 8, 4, 0, 9, 2, 1, 5])
>>> rolling_window(a, 3) == [8, 4, 0]
array([[False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

要使其真正有用,您必须使用 all 沿轴 1 减少它:

>>> numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
array([False, False, False,  True, False, False, False, False], dtype=bool)

然后你可以使用它,但是你会使用一个布尔数组。获取索引的简单方法:

>>> bool_indices = numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
>>> numpy.mgrid[0:len(bool_indices)][bool_indices]
array([3])

对于列表,您可以调整其中一个 rolling window 迭代器以使用类似的方法。

对于非常大的数组和子数组,你可以这样节省内存:

>>> windows = rolling_window(a, 3)
>>> sub = [8, 4, 0]
>>> hits = numpy.ones((len(a) - len(sub) + 1,), dtype=bool)
>>> for i, x in enumerate(sub):
...     hits &= numpy.in1d(windows[:,i], [x])
... 
>>> hits
array([False, False, False,  True, False, False, False, False], dtype=bool)
>>> hits.nonzero()
(array([3]),)

另一方面,这可能会慢一些。

【讨论】:

  • 这种方法的问题是,虽然rolling_window 的返回不需要任何新的内存,并重用原始数组的内存,但在执行== 操作时,您会实例化一个新的size 乘以原始数组的完整大小的布尔数组。如果数组足够大,这可能会严重影响性能。
  • 你说得对,这不是渐近最优解。但是,它在简单性和效率之间取得了很好的平衡——它比任何纯 Python 方法都简单、正确且速度快几个数量级。对于那些需要可证明的最优解决方案的人,norok2 的非常详细的答案有几个候选者,包括基于numba 的滚动哈希方法,该方法在预期情况下是渐近最优的。
【解决方案2】:

以下代码应该可以工作:

[x for x in xrange(len(a)) if a[x:x+len(b)] == b]

返回模式开始的索引。

【讨论】:

  • 这可能不是最快的解决方案,但 +1 是最简单的答案。这可能满足许多​​用户的需求,尤其是在 numpy 不可用的情况下。
  • 在 Python 3 中使用 range 而不是 xrange
  • 为了提高性能,您可以将len(a) 替换为len(a) - len(b) + 1
【解决方案3】:

一种基于卷积的方法,应该比基于stride_tricks 的方法更节省内存:

def find_subsequence(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq,
                                       subseq, mode='valid') == target)[0]
    # some of the candidates entries may be false positives, double check
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    return candidates[mask]

对于非常大的数组,可能无法使用stride_tricks 方法,但这个方法仍然有效:

haystack = np.random.randint(1000, size=(1e6))
needle = np.random.randint(1000, size=(100,))
# Hide 10 needles in the haystack
place = np.random.randint(1e6 - 100 + 1, size=10)
for idx in place:
    haystack[idx:idx+100] = needle

In [3]: find_subsequence(haystack, needle)
Out[3]: 
array([253824, 321497, 414169, 456777, 635055, 879149, 884282, 954848,
       961100, 973481], dtype=int64)

In [4]: np.all(np.sort(place) == find_subsequence(haystack, needle))
Out[4]: True

In [5]: %timeit find_subsequence(haystack, needle)
10 loops, best of 3: 79.2 ms per loop

【讨论】:

  • 虽然我真的很喜欢这种方法,但我应该注意,通常通过 l2 范数找到候选者并不比从针中找到特定符号更好。但是在通过计算与针长度相同的随机模式的点积进行小修改后,这种方法将非常棒。
【解决方案4】:

已编辑以包含更深入的讨论、更好的代码和更多的基准测试)


总结

对于原始速度和效率,可以使用其中一种经典算法的 Cython 或 Numba 加速版本(当输入分别是 Python 序列或 NumPy 数组时)。

推荐的方法是:

  • find_kmp_cy() 用于 Python 序列(listtuple 等)
  • find_kmp_nb() 用于 NumPy 数组

其他有效的方法是 find_rk_cy()find_rk_nb(),它们的内存效率更高,但不能保证在线性时间内运行。

如果 Cython / Numba 不可用,find_kmp()find_rk() 对于大多数用例来说都是一个很好的全方位解决方案,尽管在一般情况下和 Python 序列中,天真的方法,在某种形式下,特别是find_pivot(),可能会更快。对于 NumPy 数组,find_conv()(来自@Jaime answer)优于任何非加速的幼稚方法。

(完整代码如下,herethere。)


理论

这是计算机科学中的一个经典问题,被称为字符串搜索或字符串匹配问题。 朴素的方法基于两个嵌套循环,平均计算复杂度为O(n + m),但最坏的情况是O(n m)。 多年来,已经开发了许多alternative approaches,以保证更好的最坏情况性能。

在经典算法中,最适合通用序列的算法(因为它们不依赖于字母表)是:

最后一种算法依赖于rolling hash 的计算来提高效率,因此可能需要一些关于输入的额外知识才能获得最佳性能。 最终,它最适合同类数据,例如数字数组。 Python 中数值数组的一个值得注意的例子当然是 N​​umPy 数组。

备注

  • 朴素算法非常简单,适合在 Python 中以不同程度的运行时速度实现不同的实现。
  • 其他算法在可通过语言技巧优化的方面不太灵活。
  • Python 中的显式循环可能是速度瓶颈,可以使用多种技巧在解释器之外执行循环。
  • Cython 特别擅长加速通用 Python 代码的显式循环。
  • Numba 特别擅长加速 NumPy 数组上的显式循环。
  • 这是生成器的绝佳用例,因此所有代码都将使用这些生成器而不是常规函数。

Python 序列(listtuple 等)

基于朴素算法

  • find_loop()find_loop_cy()find_loop_nb() 分别是纯 Python、Cython 和 Numba JITing 中的显式循环实现。请注意 Numba 版本中的 forceobj=True,这是必需的,因为我们使用的是 Python 对象输入。
def find_loop(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_loop_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
find_loop_nb = nb.jit(find_loop, forceobj=True)
find_loop_nb.__name__ = 'find_loop_nb'
  • find_all() 在理解生成器上用 all() 替换内部循环
def find_all(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if all(seq[i + j] == subseq[j] for j in range(m)):
            yield i
  • find_slice() 切片后直接比较== 替换内循环[]
def find_slice(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i:i + m] == subseq:
            yield i
  • find_mix()find_mix2() 在切片 [] 后用直接比较 == 替换内部循环,但在第一个(和最后一个)字符上包含一个或两个额外的短路,这可能更快,因为使用 int 切片比使用slice() 切片要快得多。
def find_mix(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i:i + m] == subseq:
            yield i
def find_mix2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i + m - 1] == subseq[m - 1] \
                and seq[i:i + m] == subseq:
            yield i
  • find_pivot()find_pivot2() 使用子序列的第一项将外部循环替换为多个 .index() 调用,同时对内部循环使用切片,最终在最后一项上进行额外的短路(第一个匹配建造)。多个.index() 调用被包装在index_all() 生成器中(它本身可能很有用)。
def index_all(seq, item, start=0, stop=-1):
    try:
        n = len(seq)
        if n > 0:
            start %= n
            stop %= n
            i = start
            while True:
                i = seq.index(item, i)
                if i <= stop:
                    yield i
                    i += 1
                else:
                    return
        else:
            return
    except ValueError:
        pass


def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i:i + m] == subseq:
            yield i
def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i + m - 1] == subseq[m - 1] and seq[i:i + m] == subseq:
            yield i

基于 Knuth-Morris-Pratt (KMP) 算法

  • find_kmp() 是该算法的普通 Python 实现。由于没有简单的循环或可以将切片与slice() 一起使用的地方,因此除了使用 Cython 之外,没有太多需要做的优化(Numba 将再次需要 forceobj=True,这会导致代码变慢)。李>
def find_kmp(seq, subseq):
    n = len(seq)
    m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    j = 1
    k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    i = j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1
  • find_kmp_cy() 是算法的 Cython 实现,其中索引使用 C int 数据类型,这会导致代码更快。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_kmp_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    cdef Py_ssize_t j = 1
    cdef Py_ssize_t k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    cdef Py_ssize_t i = 0
    j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1

基于 Rabin-Karp (RK) 算法

  • find_rk() 是一个纯 Python 实现,它依赖 Python 的 hash() 来计算(和比较)哈希。这样的哈希是通过一个简单的sum() 滚动的。然后,通过减去刚刚访问的项目 seq[i - 1] 上的 hash() 的结果并将新考虑的项目 seq[i + m - 1] 上的 hash() 的结果相加,从先前的哈希中计算出翻转。
def find_rk(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])   # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i
  • find_rk_cy() 是算法的 Cython 实现,其中索引使用适当的 C 数据类型,这会产生更快的代码。请注意,hash() 会截断“基于主机位宽的返回值”。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_rk_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    cdef Py_ssize_t hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    cdef Py_ssize_t curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    cdef Py_ssize_t old_item, new_item
    for i in range(1, n - m + 1):
        old_item = hash(seq[i - 1])
        new_item = hash(seq[i + m - 1])
        curr_hash += new_item - old_item  # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i

基准

上述函数在两个输入上进行评估:

  • 随机输入
def gen_input(n, k=2):
    return tuple(random.randint(0, k - 1) for _ in range(n))
  • (几乎)朴素算法的最差输入
def gen_input_worst(n, k=-2):
    result = [0] * n
    result[k] = 1
    return tuple(result)

subseq 具有固定大小 (32)。 由于有很多替代方案,因此已经完成了两个单独的分组,并且省略了一些变化非常小且时间几乎相同的解决方案(即find_mix2()find_pivot2())。 对于每组,两个输入都经过测试。 对于每个基准测试,都提供了完整图和最快方法的放大图。

天真随机

最坏情况下的幼稚

其他随机

其他最差

(完整代码可用here。)


NumPy 数组

基于朴素算法

  • find_loop()find_loop_cy()find_loop_nb() 分别是纯 Python、Cython 和 Numba JITing 中的仅显式循环实现。前两个的代码与上面相同,因此省略。 find_loop_nb() 现在享受快速 JIT 编译。内部循环已编写在一个单独的函数中,因为它可以在 find_rk_nb() 中重复使用(在 Numba 函数中调用 Numba 函数不会产生 Python 典型的函数调用惩罚)。
@nb.jit
def _is_equal_nb(seq, subseq, m, i):
    for j in range(m):
        if seq[i + j] != subseq[j]:
            return False
    return True


@nb.jit
def find_loop_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if _is_equal_nb(seq, subseq, m, i):
            yield i
  • find_all()和上面一样,而find_slice()find_mix()find_mix2()和上面几乎一样,唯一的区别是seq[i:i + m] == subseq现在是np.all()的参数:np.all(seq[i:i + m] == subseq)

  • find_pivot()find_pivot2() 与上面的想法相同,只是现在使用 np.where() 而不是 index_all() 并且需要在 np.all() 调用中包含数组相等性。

def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif np.all(seq[i:i + m] == subseq):
            yield i


def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif seq[i + m - 1] == subseq[m - 1] \
                and np.all(seq[i:i + m] == subseq):
            yield i
  • find_rolling() 通过滚动窗口表示循环,并使用np.all() 检查匹配。这以创建大型临时对象为代价对所有循环进行了矢量化,同时仍然大量应用了朴素算法。 (方法来自@senderle answer)。
def rolling_window(arr, size):
    shape = arr.shape[:-1] + (arr.shape[-1] - size + 1, size)
    strides = arr.strides + (arr.strides[-1],)
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)


def find_rolling(seq, subseq):
    bool_indices = np.all(rolling_window(seq, len(subseq)) == subseq, axis=1)
    yield from np.mgrid[0:len(bool_indices)][bool_indices]
  • find_rolling2() 是上述的一种内存效率稍高的变体,其中矢量化只是部分的,并且保留了一个显式循环(沿着预期的最短维度——subseq 的长度)。 (方法也来自@senderle answer)。
def find_rolling2(seq, subseq):
    windows = rolling_window(seq, len(subseq))
    hits = np.ones((len(seq) - len(subseq) + 1,), dtype=bool)
    for i, x in enumerate(subseq):
        hits &= np.in1d(windows[:, i], [x])
    yield from hits.nonzero()[0]

基于 Knuth-Morris-Pratt (KMP) 算法

  • find_kmp() 与上面相同,而 find_kmp_nb() 是直接的 JIT 编译。
find_kmp_nb = nb.jit(find_kmp)
find_kmp_nb.__name__ = 'find_kmp_nb'

基于 Rabin-Karp (RK) 算法

  • find_rk() 与上述相同,只是seq[i:i + m] == subseq 再次包含在np.all() 调用中。

  • find_rk_nb() 是上述 Numba 加速版本。使用前面定义的 _is_equal_nb() 来确定匹配,而对于散列,它使用 Numba 加速的 sum_hash_nb() 函数,其定义非常简单。

@nb.jit
def sum_hash_nb(arr):
    result = 0
    for x in arr:
        result += hash(x)
    return result


@nb.jit
def find_rk_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if _is_equal_nb(seq, subseq, m, 0):
        yield 0
    hash_subseq = sum_hash_nb(subseq)  # compute hash
    curr_hash = sum_hash_nb(seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])  # update hash
        if hash_subseq == curr_hash and _is_equal_nb(seq, subseq, m, i):
            yield i
  • find_conv() 使用伪 Rabin-Karp 方法,其中初始候选者使用 np.dot() 乘积进行散列,并位于 seqsubseqnp.where() 之间的卷积上。该方法是伪的,因为虽然它仍然使用散列来识别可能的候选者,但它可能不被视为滚动散列(它取决于np.correlate() 的实际实现。另外,它需要创建一个大小为输入。(方法来自@Jaime answer)。
def find_conv(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq, subseq, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    yield from candidates[mask]

基准

和以前一样,上面的函数在两个输入上进行评估:

  • 随机输入
def gen_input(n, k=2):
    return np.random.randint(0, k, n)
  • (几乎)朴素算法的最差输入
def gen_input_worst(n, k=-2):
    result = np.zeros(n, dtype=int)
    result[k] = 1
    return result

subseq 具有固定大小 (32)。 该图遵循与之前相同的方案,为方便起见总结如下。

由于有很多替代方案,因此已经完成了两个单独的分组,并且省略了一些变化非常小且时间几乎相同的解决方案(即find_mix2()find_pivot2())。 对于每组,两个输入都经过测试。 对于每个基准测试,都提供了完整图和最快方法的放大图。

天真随机

最坏情况下的幼稚

其他随机

其他最差

(完整代码可用here。)

【讨论】:

  • 感谢您运行所有这些测试!我从我的回答中链接到这个。我仍然喜欢我的,因为它相当快,易于推理,并且不添加任何依赖项。但对于需要真正最佳解决方案的人来说,这很棒。不过,我不同意 KMP 方法是最好的。对于绝大多数实际用例,RK 更快,很少有人真正需要 KMP 提供的最坏情况保证。
【解决方案5】:

可以调用 tostring() 方法将数组转换为字符串,然后可以使用快速字符串搜索。当您要检查许多子数组时,此方法可能会更快。

import numpy as np

a = np.array([1,2,3,4,5,6])
b = np.array([2,3,4])
print a.tostring().index(b.tostring())//a.itemsize

【讨论】:

  • 这个解决方案非常快速和优雅,非常感谢!稍微相关的是,我有一个项目使用 SWIG 包装器从 C++ 中抓取大约 1e8 个元素的 np 数组,并且数组创建非常慢。将它们作为字符串使用可提高实时性能
  • 方法不正确。见np.array([0, 1]).tostring().index(np.array([256]).tostring())
【解决方案6】:

又一次尝试,但我确信有更多 Pythonic 和有效的方法可以做到这一点......

定义数组匹配(a,b): 对于 xrange(0, len(a)-len(b)+1) 中的 i: 如果 a[i:i+len(b)] == b: 返回我 返回无 a = [1, 2, 3, 4, 5, 6] b = [2, 3, 4] 打印数组匹配(a,b) 1

(正如 cdhowie 提到的,第一个答案不在问题范围内)

set(a) & set(b) == set(b)

【讨论】:

  • 两个问题:这也会匹配[1, 3, 2, 4, 5, 6](集合不排序;数组是),并且它不报告匹配的位置(应该是索引1)。
  • 是的,我的错,回答太快了:-/
  • 您可以通过将first_occurence=i 替换为return i 并将return first_occurence 替换为return None 来简化您的代码。
【解决方案7】:

这是一个相当直接的选项:

def first_subarray(full_array, sub_array):
    n = len(full_array)
    k = len(sub_array)
    matches = np.argwhere([np.all(full_array[start_ix:start_ix+k] == sub_array) 
                   for start_ix in range(0, n-k+1)])
    return matches[0]

然后使用我们得到的原始a、b向量:

a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
first_subarray(a, b)
Out[44]: 
array([1], dtype=int64)

【讨论】:

  • 你可能会添加一些逻辑来处理没有匹配的情况......
【解决方案8】:

三个提议的解决方案的快速比较(随机创建的向量的平均迭代时间为 100 次。):

import time
import collections
import numpy as np


def function_1(seq, sub):
    # direct comparison
    seq = list(seq)
    sub = list(sub)
    return [i for i in range(len(seq) - len(sub)) if seq[i:i+len(sub)] == sub]

def function_2(seq, sub):
    # Jamie's solution
    target = np.dot(sub, sub)
    candidates = np.where(np.correlate(seq, sub, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(sub))
    mask = np.all((np.take(seq, check) == sub), axis=-1)
    return candidates[mask]

def function_3(seq, sub):
    # HYRY solution
    return seq.tostring().index(sub.tostring())//seq.itemsize


# --- assessment time performance
N = 100

seq = np.random.choice([0, 1, 2, 3, 4, 5, 6], 3000)
sub = np.array([1, 2, 3])

tim = collections.OrderedDict()
tim.update({function_1: 0.})
tim.update({function_2: 0.})
tim.update({function_3: 0.})

for function in tim.keys():
    for _ in range(N):
        seq = np.random.choice([0, 1, 2, 3, 4], 3000)
        sub = np.array([1, 2, 3])
        start = time.time()
        function(seq, sub)
        end = time.time()
        tim[function] += end - start

timer_dict = collections.OrderedDict()
for key, val in tim.items():
    timer_dict.update({key.__name__: val / N})

print(timer_dict)

这会导致(在我的旧机器上):

OrderedDict([
('function_1', 0.0008518099784851074), 
('function_2', 8.157730102539063e-05), 
('function_3', 6.124973297119141e-06)
])

【讨论】:

    【解决方案9】:

    首先,将列表转换为字符串。

    a = ''.join(str(i) for i in a)
    b = ''.join(str(i) for i in b)
    

    转换成字符串后,可以通过下面的字符串函数轻松找到子字符串的索引。

    a.index(b)
    

    干杯!!

    【讨论】:

      猜你喜欢
      • 2018-06-30
      • 2019-01-07
      • 1970-01-01
      • 2018-11-14
      • 1970-01-01
      • 2017-08-23
      • 2015-09-01
      • 2019-12-20
      • 2017-09-14
      相关资源
      最近更新 更多