【问题标题】:Efficiently return the index of the first value satisfying condition in array有效地返回数组中第一个满足条件的值的索引
【发布时间】:2019-03-31 22:11:44
【问题描述】:

我需要在满足条件的一维 NumPy 数组或 Pandas 数字系列中找到第一个值的索引。数组很大,索引可能靠近数组的开始结束,条件可能根本不满足。我无法提前判断哪个更有可能。如果不满足条件,则返回值应为-1。我考虑了几种方法。

尝试 1

# func(arr) returns a Boolean array
idx = next(iter(np.where(func(arr))[0]), -1)

但这通常太慢了,因为func(arr)整个 数组上应用了矢量化函数,而不是在满足条件时停止。具体来说,当条件在数组的开始附近满足时,它是昂贵的。

尝试 2

np.argmax 稍微快一些,但无法识别何时从未满足条件:

np.random.seed(0)
arr = np.random.rand(10**7)

assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)

%timeit next(iter(np.where(arr > 0.999999)[0]), -1)  # 21.2 ms
%timeit np.argmax(arr > 0.999999)                    # 17.7 ms

np.argmax(arr > 1.0) 返回0,即条件满足时的实例。

尝试 3

# func(arr) returns a Boolean scalar
idx = next((idx for idx, val in enumerate(arr) if func(arr)), -1)

但是当在数组的end附近满足条件时,这太慢了。这可能是因为生成器表达式因大量__next__ 调用而产生了昂贵的开销。

总是是一种妥协还是有办法,对于通用func,有效地提取第一个索引?

基准测试

对于基准测试,假设func 在值大于给定常数时找到索引:

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

# Start of array benchmark
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

# End of array benchmark
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms

【问题讨论】:

    标签: python arrays pandas performance numpy


    【解决方案1】:

    numba

    使用numba,可以优化两个场景。从语法上讲,您只需要使用简单的for 循环构造一个函数:

    from numba import njit
    
    @njit
    def get_first_index_nb(A, k):
        for i in range(len(A)):
            if A[i] > k:
                return i
        return -1
    
    idx = get_first_index_nb(A, 0.9)
    

    Numba 通过 JIT(“及时”)编译代码和利用 CPU-level optimisations 来提高性能。没有@njit 装饰器的常规 for 循环通常会 比您已经尝试过的条件迟到的情况。 p>

    对于 Pandas 数字系列 df['data'],您可以简单地将 NumPy 表示提供给 JIT 编译的函数:

    idx = get_first_index_nb(df['data'].values, 0.9)
    

    概括

    由于numba 允许functions as arguments,并且假设传递的函数也可以进行 JIT 编译,您可以得出一种方法来计算满足条件的第 n 个索引任意func

    @njit
    def get_nth_index_count(A, func, count):
        c = 0
        for i in range(len(A)):
            if func(A[i]):
                c += 1
                if c == count:
                    return i
        return -1
    
    @njit
    def func(val):
        return val > 0.9
    
    # get index of 3rd value where func evaluates to True
    idx = get_nth_index_count(arr, func, 3)
    

    对于第三个 last 值,您可以提供相反的值,arr[::-1],并否定来自 len(arr) - 1 的结果,- 1 是考虑 0 索引所必需的。

    性能基准测试

    # Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
    
    np.random.seed(0)
    arr = np.random.rand(10**7)
    m = 0.9
    n = 0.999999
    
    @njit
    def get_first_index_nb(A, k):
        for i in range(len(A)):
            if A[i] > k:
                return i
        return -1
    
    def get_first_index_np(A, k):
        for i in range(len(A)):
            if A[i] > k:
                return i
        return -1
    
    %timeit get_first_index_nb(arr, m)                                 # 375 ns
    %timeit get_first_index_np(arr, m)                                 # 2.71 µs
    %timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
    %timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs
    
    %timeit get_first_index_nb(arr, n)                                 # 204 µs
    %timeit get_first_index_np(arr, n)                                 # 44.8 ms
    %timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
    %timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms
    

    【讨论】:

      【解决方案2】:

      我也想做类似的事情,发现这个问题中提出的解决方案并没有真正帮助我。特别是,numba 解决方案对我来说比问题本身中提出的更传统的方法要慢得多。我有一个times_all 列表,通常大约有数万个元素,并且想要找到大于time_eventtimes_all 的第一个元素的索引。我有成千上万的time_events。我的解决方案是将times_all划分为例如100个元素的块,首先确定time_event属于哪个时间段,保留该段的第一个元素的索引,然后找到该段中的哪个索引,并将两者相加指数。这是一个最小的代码。对我来说,它的运行速度比本页中的其他解决方案快几个数量级。

      def event_time_2_index(time_event, times_all, STEPS=100):
          import numpy as np
          time_indices_jumps = np.arange(0, len(times_all), STEPS)
          time_list_jumps = [times_all[idx] for idx in time_indices_jumps]
      
          time_list_jumps_idx = next((idx for idx, val in enumerate(time_list_jumps)\
                                if val > time_event), -1)
          index_in_jumps = time_indices_jumps[time_list_jumps_idx-1]
          times_cropped = times_all[index_in_jumps:]
          event_index_rel = next((idx for idx, val in enumerate(times_cropped) \
                            if val > time_event), -1)
      
          event_index = event_index_rel + index_in_jumps
          return event_index
      

      【讨论】:

      • 您能否提供一些示例输入来演示如何更快?我很惊讶(除非在很早满足条件的特定情况下)生成器表达式会很有效。您使用next + 生成器表达式的逻辑本质上是我的尝试#3。
      • 我处理的数据是实验数据,现阶段我无法分享它们。但我有一个排序的时间步长数组,步长为 1/320 秒,大约为 1/320 秒。 1e5 个样本和另一个事件时间数组,通常为数千个。我需要这些事件的索引,以便在 EEG 分析工具中使用。使用这种分割技巧,对于 1e5 样本,最大比较次数是 1000+100,但如果没有这种分割,最多可以是 1e5-1。我使用了next 生成器,因为在你所做的基准测试中是最快的,而且它只有一行。
      • 而且,对我来说,numba 函数比其他解决方案慢,这不是我所期望的。虽然我应该说我在 Spyder 上运行我的代码,但我知道这在内存管理方面真的很糟糕,所以也许这起到了作用:stackoverflow.com/questions/57409470/…
      • I have a sorted array of time steps - 这是一个额外的假设,不能从问题中假设。我知道您要解决这个问题,但因此我相信您的回答可能是对 不同 问题的好回答。如果您要使用附加标准编写自己的问答,它可能会得到更好的接收。 [尽管您应该模拟示例输入数据,就像我在问答中所做的那样。]
      • 我在谷歌上搜索了我的问题,被引导到这个问答环节,解决方案没有帮助,我想到了一个想法,它帮助我在几小时而不是几天内完成我想做的事情,并想与任何可能被带到虚拟世界这个角落的人分享这个想法。如果对别人有帮助,那很好,但如果它不受欢迎,我不给一只会飞的火烈鸟!
      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-11-30
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多