【问题标题】:How to find all elements in a numpy 2-dimensional array that match a certain list?如何在 numpy 二维数组中找到与某个列表匹配的所有元素?
【发布时间】:2016-04-28 23:15:05
【问题描述】:

我有一个二维 NumPy 数组,例如:

array([[1, 1, 0, 2, 2],
       [1, 1, 0, 2, 0],
       [0, 0, 0, 0, 0],
       [3, 3, 0, 4, 4],
       [3, 3, 0, 4, 4]])

我想从该数组中获取某个列表中的所有元素,例如 (1, 3, 4)。示例案例中的预期结果是:

array([[1, 1, 0, 0, 0],
       [1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 3, 0, 4, 4],
       [3, 3, 0, 4, 4]])

我知道我可以做到(这里推荐Numpy: find elements within range):

np.logical_or(
    np.logical_or(cc_labeled == 1, cc_labeled == 3),
    cc_labeled == 4
)

,但这只会在示例情况下合理有效。实际上,迭代使用 for 循环和 numpy.logical_or 变得非常慢,因为可能值的列表以千为单位(并且 numpy 数组的维度大约为 1000 x 1000)。

【问题讨论】:

    标签: python arrays performance numpy vectorization


    【解决方案1】:

    你可以使用np.in1d -

    A*np.in1d(A,[1,3,4]).reshape(A.shape)
    

    另外,np.where 也可以使用 -

    np.where(np.in1d(A,[1,3,4]).reshape(A.shape),A,0)
    

    您还可以使用np.searchsorted 来查找此类匹配项,方法是使用其可选的'side' 参数,输入为leftright,并注意对于匹配项,searchsorted 将使用这两个输入输出不同的结果。因此,np.in1d(A,[1,3,4]) 的等价物将是 -

    M = np.searchsorted([1,3,4],A.ravel(),'left') != \
        np.searchsorted([1,3,4],A.ravel(),'right')
    

    因此,最终输出将是 -

    out = A*M.reshape(A.shape)
    

    请注意,如果输入搜索列表未排序,则需要在np.searchsorted 中使用可选参数sorter 及其argsort 索引。

    示例运行 -

    In [321]: A
    Out[321]: 
    array([[1, 1, 0, 2, 2],
           [1, 1, 0, 2, 0],
           [0, 0, 0, 0, 0],
           [3, 3, 0, 4, 4],
           [3, 3, 0, 4, 4]])
    
    In [322]: A*np.in1d(A,[1,3,4]).reshape(A.shape)
    Out[322]: 
    array([[1, 1, 0, 0, 0],
           [1, 1, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [3, 3, 0, 4, 4],
           [3, 3, 0, 4, 4]])
    
    In [323]: np.where(np.in1d(A,[1,3,4]).reshape(A.shape),A,0)
    Out[323]: 
    array([[1, 1, 0, 0, 0],
           [1, 1, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [3, 3, 0, 4, 4],
           [3, 3, 0, 4, 4]])
    
    In [324]: M = np.searchsorted([1,3,4],A.ravel(),'left') != \
         ...:     np.searchsorted([1,3,4],A.ravel(),'right')
         ...: A*M.reshape(A.shape)
         ...: 
    Out[324]: 
    array([[1, 1, 0, 0, 0],
           [1, 1, 0, 0, 0],
           [0, 0, 0, 0, 0],
           [3, 3, 0, 4, 4],
           [3, 3, 0, 4, 4]])
    

    运行时测试和验证输出 -

    In [309]: # Inputs
         ...: A = np.random.randint(0,1000,(400,500))
         ...: lst = np.sort(np.random.randint(0,1000,(100))).tolist()
         ...: 
         ...: def func1(A,lst):                         
         ...:   return A*np.in1d(A,lst).reshape(A.shape)
         ...: 
         ...: def func2(A,lst):                         
         ...:   return np.where(np.in1d(A,lst).reshape(A.shape),A,0)
         ...: 
         ...: def func3(A,lst):                         
         ...:   mask = np.searchsorted(lst,A.ravel(),'left') != \
         ...:          np.searchsorted(lst,A.ravel(),'right')
         ...:   return A*mask.reshape(A.shape)
         ...: 
    
    In [310]: np.allclose(func1(A,lst),func2(A,lst))
    Out[310]: True
    
    In [311]: np.allclose(func1(A,lst),func3(A,lst))
    Out[311]: True
    
    In [312]: %timeit func1(A,lst)
    10 loops, best of 3: 30.9 ms per loop
    
    In [313]: %timeit func2(A,lst)
    10 loops, best of 3: 30.9 ms per loop
    
    In [314]: %timeit func3(A,lst)
    10 loops, best of 3: 28.6 ms per loop
    

    【讨论】:

    • 我选择了 np.where(...) 变体,因为它最直观易懂。谢谢!
    【解决方案2】:

    使用np.in1d:

    np.in1d(arr, [1,3,4]).reshape(arr.shape)
    

    in1d,顾名思义,是对展平的数组进行操作,所以操作后需要reshape。

    【讨论】:

      猜你喜欢
      • 2011-12-22
      • 2021-12-19
      • 2021-01-11
      • 2015-05-23
      • 2014-11-07
      • 1970-01-01
      • 2021-04-26
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多