【问题标题】:Find indexes of matching rows in two 2-D arrays在两个二维数组中查找匹配行的索引
【发布时间】:2013-11-26 23:47:35
【问题描述】:

假设我有如下两个二维数组:

array([[3, 3, 1, 0],
       [2, 3, 1, 3],
       [0, 2, 3, 1],
       [1, 0, 2, 3],
       [3, 1, 0, 2]], dtype=int8)

array([[0, 3, 3, 1],
       [0, 2, 3, 1],
       [1, 0, 2, 3],
       [3, 1, 0, 2],
       [3, 3, 1, 0]], dtype=int8)

每个数组中的某些行在另一个数组中具有按值(但不一定按索引)匹配的对应行,而有些则没有。

我想找到一种有效的方法来返回两个数组中对应于匹配行的索引对。如果它们是元组,我希望返回

(0,4)
(2,1)
(3,2)
(4,3)

【问题讨论】:

    标签: python numpy


    【解决方案1】:

    我想不出一个具体的 numpy 方法来做到这一点,但这是我对常规列表的处理方式:

    >>> L1= [[3, 3, 1, 0],
    ...        [2, 3, 1, 3],
    ...        [0, 2, 3, 1],
    ...        [1, 0, 2, 3],
    ...        [3, 1, 0, 2]]
    >>> L2 = [[0, 3, 3, 1],
    ...        [0, 2, 3, 1],
    ...        [1, 0, 2, 3],
    ...        [3, 1, 0, 2],
    ...        [3, 3, 1, 0]]
    >>> L1 = {tuple(row):i for i,row in enumerate(L1)}
    >>> answer = []
    >>> for i,row in enumerate(L2):
    ...   if tuple(row) in L1:
    ...     answer.append((L1[tuple(row)], i))
    ... 
    >>> answer
    [(2, 1), (3, 2), (4, 3), (0, 4)]
    

    【讨论】:

    • O(n)!好的。但是没有一种 numpy 的方法可以做到吗?
    • @slider: I can't think of a numpy way to do it,主要是因为我不经常使用 numpy(它在我的待办事项列表上的时间比我自豪承认的要长)
    • 这是否可以概括为L2只有一行的情况,我们希望获得L1中匹配行的“行索引”,L1中的行不一定独一无二?
    【解决方案2】:

    这是一个全部为numpy 的解决方案 - 不一定比迭代 Python 解决方案更好。它仍然需要查看所有组合。

    In [53]: np.array(np.all((x[:,None,:]==y[None,:,:]),axis=-1).nonzero()).T.tolist()
    Out[53]: [[0, 4], [2, 1], [3, 2], [4, 3]]
    

    中间数组是(5,5,4)np.all 将其简化为:

    array([[False, False, False, False,  True],
           [False, False, False, False, False],
           [False,  True, False, False, False],
           [False, False,  True, False, False],
           [False, False, False,  True, False]], dtype=bool)
    

    剩下的只是提取True的索引

    在粗略测试中,这次是 47.8 us; L1 字典的另一个答案是 38.3 us;第三个在 496 us 处有双循环。

    【讨论】:

      【解决方案3】:

      您可以使用 void 数据类型技巧在两个数组的行上使用一维函数。 a_viewb_view 是一维向量,每个条目代表一整行。然后,我选择对一个数组进行排序并使用np.searchsorted 在该数组中查找另一个数组的项目。如果我们排序的数组长度为m,另一个长度为n,排序需要时间m * log(m),而np.searchsorted 进行的二分查找需要时间n * log(m),总共需要时间(n + m) * log(m)。因此,您希望对两个数组中最短的一个进行排序:

      def find_rows(a, b):
          dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
      
          a_view = np.ascontiguousarray(a).view(dt).ravel()
          b_view = np.ascontiguousarray(b).view(dt).ravel()
      
          sort_b = np.argsort(b_view)
          where_in_b = np.searchsorted(b_view, a_view,
                                       sorter=sort_b)
          where_in_b = np.take(sort_b, where_in_b)
          which_in_a = np.take(b_view, where_in_b) == a_view
          where_in_b = where_in_b[which_in_a]
          which_in_a = np.nonzero(which_in_a)[0]
          return np.column_stack((which_in_a, where_in_b))
      

      使用 ab 您的两个示例数组:

      In [14]: find_rows(a, b)
      Out[14]: 
      array([[0, 4],
             [2, 1],
             [3, 2],
             [4, 3]], dtype=int64)
      
      In [15]: %timeit find_rows(a, b)
      10000 loops, best of 3: 29.7 us per loop
      

      在我的系统上,对于您的测试数据,字典方法的时钟速度大约为 22 us,但对于 1000x4 的数组,这种 numpy 方法比纯 Python 方法快大约 6 倍(483 us vs 2.54 ms)。

      【讨论】:

      • 这太棒了。我花了整整一个小时才弄清楚你在做什么。虽然有一个小错误,因为 searchsorted 有可能返回应该在末尾插入的项目,这会导致索引超出范围错误。
      • 例如,只需将 a 数组的最后一行更改为 [3,3,3,3] 即可得到IndexError: index 5 is out of bounds for size 5
      • 这确实加快了我的代码速度,非常感谢。在行上使用 dict() 无法将其剪切为 10^4 或更多行。
      猜你喜欢
      • 2021-03-03
      • 2014-11-07
      • 1970-01-01
      • 2018-02-22
      • 1970-01-01
      • 2019-07-24
      • 2017-06-18
      • 2019-07-05
      • 2019-08-31
      相关资源
      最近更新 更多