【问题标题】:searching sorted items into a sorted sequence将已排序的项目搜索到已排序的序列中
【发布时间】:2015-04-16 17:05:34
【问题描述】:

我想在已排序的值数组中查找一系列项目。 我知道使用 numpy 我可以做到:

l = np.searchsorted(values, items)

这具有 O(len(items)*log(len(values))) 的复杂性。 但是,我的项目也已排序,因此我可以在 O(len(items)+len(values)) 中进行:

l = np.zeros(items.size, dtype=np.int32)
k, K = 0, len(values)
for i in range(len(items)):
    while k < K and values[k] < items[i]:
        k += 1
    l[i] = k

问题在于,由于 python 循环,纯 python 中的这个版本比 searchsorted 慢得多,即使对于大的 len(items) 和 len(values) (~10^6)。

知道如何用 numpy “矢量化”这个循环吗?

【问题讨论】:

  • 你不能对 items 中的每个项目使用searchsorted,同时将值从最后找到的索引切片到末尾吗?也许这会加快速度。
  • 如果值中不存在 items[i],您的方法会正常工作吗?我认为它表示第一个值大于 items[i]
  • 即使你对它进行矢量化,你仍然会在 python 的土地上......而且 numpy 会把你吹出水面
  • @JoranBeasley 通过“矢量化”我的意思是使用 numpy 原语。
  • @PeterE 建议的方法已经内置在 numpy 1.9 中,this 是相关的 PR。

标签: python performance numpy big-o binary-search


【解决方案1】:

一些示例数据:

# some example data
np.random.seed(0)
n_values = 1000000
n_items = 100000
values = np.random.rand(n_values)
items = np.random.rand(n_items)
values.sort()
items.sort()

您的原始代码 sn-p 以及 @PeterE 建议的实现:

def original(values, items):
    l = np.empty(items.size, dtype=np.int32)
    k, K = 0, len(values)
    for i, item in enumerate(items):
        while k < K and values[k] < item:
            k += 1
        l[i] = k
    return l

def peter_e(values, items):
    l = np.empty(items.size, dtype=np.int32)
    last_idx = 0
    for i, item in enumerate(items):
        last_idx += values[last_idx:].searchsorted(item)
        l[i] = last_idx
    return l

针对 naive np.searchsorted 测试正确性:

ss = values.searchsorted(items)

print(all(original(values, items) == ss))
# True

print(all(peter_e(values, items) == ss))
# True

时间安排:

In [1]: %timeit original(values, items)
10 loops, best of 3: 115 ms per loop

In [2]: %timeit peter_e(values, items)
10 loops, best of 3: 79.8 ms per loop

In [3]: %timeit values.searchsorted(items)
100 loops, best of 3: 4.09 ms per loop

因此,对于这种大小的输入,天真的使用 np.searchsorted 轻松击败您的原始代码以及 PeterE 的建议。

更新

为避免任何可能导致时序偏差的缓存效应,我们可以为每次基准测试迭代生成一组新的随机输入数组:

In [1]: %%timeit values = np.random.randn(n_values); items = np.random.randn(n_items); values.sort(); items.sort();
original(values, items)
   .....: 
10 loops, best of 3: 115 ms per loop

In [2]: %%timeit values = np.random.randn(n_values); items = np.random.randn(n_items); values.sort(); items.sort();
peter_e(values, items)
   .....: 
10 loops, best of 3: 79.9 ms per loop

In [3]: %%timeit values = np.random.randn(n_values); items = np.random.randn(n_items); values.sort(); items.sort();
values.searchsorted(items)
   .....: 
100 loops, best of 3: 4.08 ms per loop

更新 2

valuesitems 都被排序的情况下,编写一个能击败np.searchsorted 的Cython 函数并不难。

search_doubly_sorted.pyx:

import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def search_doubly_sorted(values, items):

    cdef:
        double[:] _values = values.astype(np.double)
        double[:] _items = items.astype(np.double)
        long n_items = items.shape[0]
        long n_values = values.shape[0]
        long[:] out = np.empty(n_items, dtype=np.int64)
        long ii, jj, last_idx

    last_idx = 0
    for ii in range(n_items):
        for jj in range(last_idx, n_values):
             if _items[ii] <= _values[jj]:
                break
        last_idx = jj
        out[ii] = last_idx

    return out.base

正确性测试:

In [1]: from search_doubly_sorted import search_doubly_sorted

In [2]: print(all(search_doubly_sorted(values, items) == values.searchsorted(items)))                     
# True

基准测试:

In [3]: %timeit values.searchsorted(items)
100 loops, best of 3: 4.07 ms per loop

In [4]: %timeit search_doubly_sorted(values, items)
1000 loops, best of 3: 1.44 ms per loop

不过,性能提升相当有限。除非这是您代码中的严重瓶颈,否则您应该坚持使用np.searchsorted

【讨论】:

  • 不应该是l[i] = l[i-1] + values[l[i-1]:].searchsorted(item)吗?您必须从最后找到的索引处开始切片。
  • 另外:你能对它进行计时(并且可能验证)而不是简单地使用values.searchsort(items)吗?
  • 感谢这些时间安排,但我认为要看到显着差异,数组的大小应该真的更大(然后 np.searchsorted 基线真的很快)。我的问题确实是 python 开销(隐藏在 big-o 的常数因子中)非常大。
  • @PeterE 只要没有关系,我的版本就可以运行,但你的版本更强大。我会更新我的代码。
  • 我已经分析过了,这是瓶颈。我正在处理非常大的数组,因此对于我而言,大 O 复杂性不仅是理论上的。 cythonized 版本真的要快得多。
猜你喜欢
  • 1970-01-01
  • 2011-07-10
  • 2014-01-22
  • 1970-01-01
  • 1970-01-01
  • 2016-05-17
  • 2012-12-24
  • 2020-12-20
  • 2012-11-27
相关资源
最近更新 更多