【问题标题】:Get groups of consecutive elements of a NumPy array based on condition根据条件获取 NumPy 数组的连续元素组
【发布时间】:2019-07-05 21:58:16
【问题描述】:

我有一个如下的 NumPy 数组:

import numpy as np
a = np.array([1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8])

还有一个常数b = 6

基于previous question,我可以计算c 的数量,该数量由a 中的元素连续2 次或多次小于b 的次数定义。

from itertools import groupby
b = 6
sum(len(list(g))>=2 for i, g in groupby(a < b) if i)

所以在这个例子中c == 3

现在我想在每次满足条件时输出一个数组,而不是计算满足条件的次数。

所以对于这个例子,正确的输出应该是:

array1 = [1, 4, 2]
array2 = [4, 4]
array3 = [3, 4, 4, 5]

因为:

1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8  # numbers in a
1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0  # (a<b)
^^^^^^^-----^^^^-----------------------------^^^^^^^^^^---  # (a<b) 2+ times consecutively
   1         2                                    3

到目前为止,我尝试了不同的选择:

np.isin((len(list(g))>=2 for i, g in groupby(a < b)if i), a)

np.extract((len(list(g))>=2 for i, g in groupby(a < b)if i), a)

但他们都没有达到我正在寻找的目标。有人可以为我指出正确的 Python 工具以输出满足我条件的不同数组吗?

【问题讨论】:

    标签: python arrays numpy grouping


    【解决方案1】:

    在测量my other answer 的性能时,我注意到虽然它比Austin's solution 快(对于长度

    基于this answer,我使用np.split 提出了以下解决方案,这比之前在此处添加的两个答案都更有效:

    array = np.append(a, -np.inf)  # padding so we don't lose last element
    mask = array >= 6  # values to be removed
    split_indices = np.where(mask)[0]
    for subarray in np.split(array, split_indices + 1):
        if len(subarray) > 2:
            print(subarray[:-1])
    

    给予:

    [1. 4. 2.]
    [4. 4.]
    [3. 4. 4. 5.]
    

    性能*:

    *由perfplot测量

    【讨论】:

      【解决方案2】:

      使用groupby 抓取群组:

      from itertools import groupby
      
      lst = []
      b = 6
      for i, g in groupby(a, key=lambda x: x < b):
          grp = list(g)
          if i and len(grp) >= 2:
              lst.append(grp)
      
      print(lst)
      
      # [[1, 4, 2], [4, 4], [3, 4, 4, 5]]
      

      【讨论】:

        【解决方案3】:

        此任务与image labeling 非常相似,但在您的情况下,它是一维的。 SciPy 库为我们可以在这里使用的图像处理提供了一些有用的功能:

        import numpy as np
        from scipy.ndimage import (binary_dilation,
                                   binary_erosion,
                                   label)
        
        a = np.array([1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8])
        b = 6  # your threshold
        min_consequent_count = 2
        
        mask = a < b
        structure = [False] + [True] * min_consequent_count  # used for erosion and dilation
        eroded = binary_erosion(mask, structure)
        dilated = binary_dilation(eroded, structure)
        labeled_array, labels_count = label(dilated)  # labels_count == c
        
        for label_number in range(1, labels_count + 1):  # labeling starts from 1
            subarray = a[labeled_array == label_number]
            print(subarray)
        

        给予:

        [1 4 2]
        [4 4]
        [3 4 4 5]
        

        说明:

        1. mask = a &lt; b 在元素小于阈值b 时返回带有True 值的boolean array

          array([ True,  True,  True, False,  True,  True, False,  True, False,
                 False,  True, False, False,  True, False,  True,  True,  True,
                  True, False])
          
        2. 如您所见,结果包含一些 True 元素,它们周围没有任何其他 True 邻居。为了消除它们,我们可以使用binary erosion。我为此使用scipy.ndimage.binary_erosion。它的默认structure 参数不适合我们的需要,因为它还会删除两个后续的True 值,所以我自己构建:

          >>> structure = [False] + [True] * min_consequent_count
          >>> structure
          [False, True, True]
          >>> eroded = binary_erosion(mask, structure)
          >>> eroded
          array([ True,  True, False, False,  True, False, False, False, False,
                 False, False, False, False, False, False,  True,  True,  True,
                 False, False])
          
        3. 我们设法删除了单个 True 值,但我们需要获取其他组的初始配置。为此,我们使用binary dilation 和相同的structure

          >>> dilated = binary_dilation(eroded, structure)
          >>> dilated
          array([ True,  True,  True, False,  True,  True, False, False, False,
                 False, False, False, False, False, False,  True,  True,  True,
                  True, False])
          

          binary_dilation 的文档:link

        4. 最后一步,我们用scipy.ndimage.label标记每个组:

          >>> labeled_array, labels_count = label(dilated)
          >>> labeled_array
          array([1, 1, 1, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 0])
          >>> labels_count
          3
          

          您可以看到labels_countc 值相同 - 问题中的组数。 从这里您可以简单地通过布尔索引获取子组:

          >>> a[labeled_array == 1]
          array([1, 4, 2])
          >>> a[labeled_array == 3]
          array([3, 4, 4, 5])
          

        【讨论】:

          猜你喜欢
          • 1970-01-01
          • 2019-11-14
          • 1970-01-01
          • 1970-01-01
          • 2020-06-06
          • 2011-11-13
          • 1970-01-01
          • 2021-02-15
          • 2019-04-08
          相关资源
          最近更新 更多