【问题标题】:Find indices of values greater than a threshold by row in a numpy 2darray在 numpy 二维数组中逐行查找大于阈值的值的索引
【发布时间】:2019-06-21 08:04:30
【问题描述】:

我有一个 2darray 如下。我想按数组中的每一行查找高于阈值(例如 0.7)的值的索引。

items= np.array([[1.        , 0.40824829, 0.03210806, 0.29488391, 0.        ,
        0.5       , 0.32444284, 0.57735027, 0.        , 0.5       ],
       [0.40824829, 1.        , 0.57675476, 0.48154341, 0.        ,
        0.81649658, 0.79471941, 0.70710678, 0.57735027, 0.40824829],
       [0.03210806, 0.57675476, 1.        , 0.42606683, 0.        ,
        0.        , 0.92713363, 0.834192  , 0.        , 0.73848549],
       [0.29488391, 0.48154341, 0.42606683, 1.        , 0.        ,
        0.29488391, 0.52620136, 0.51075392, 0.20851441, 0.44232587],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.81649658, 0.        , 0.29488391, 0.        ,
        1.        , 0.32444284, 0.28867513, 0.70710678, 0.        ],
       [0.32444284, 0.79471941, 0.92713363, 0.52620136, 0.        ,
        0.32444284, 1.        , 0.93658581, 0.22941573, 0.81110711],
       [0.57735027, 0.70710678, 0.834192  , 0.51075392, 0.        ,
        0.28867513, 0.93658581, 1.        , 0.        , 0.8660254 ],
       [0.        , 0.57735027, 0.        , 0.20851441, 0.        ,
        0.70710678, 0.22941573, 0.        , 1.        , 0.        ],
       [0.5       , 0.40824829, 0.73848549, 0.44232587, 0.        ,
        0.        , 0.81110711, 0.8660254 , 0.        , 1.        ]])

indices_items = np.argwhere(items>= 0.7)

这个(indices_items)返回

array([[0, 0],
       [1, 1],
       [1, 5],
       [1, 6],
       [1, 7],
       [2, 2],
       [2, 6],
       [2, 7],
       [2, 9],
       [3, 3],
       [5, 1],
       [5, 5],
       [5, 8],
       [6, 1],
       [6, 2],
       [6, 6],
       [6, 7],
       [6, 9],
       [7, 1],
       [7, 2],
       [7, 6],
       [7, 7],
       [7, 9],
       [8, 5],
       [8, 8],
       [9, 2],
       [9, 6],
       [9, 7],
       [9, 9]], dtype=int64)

如何按如下方式获取索引? row0 -> [0] row1-> [0,1,5,6,7] row2-> [2,6,7,9] row3-> [3] row4-> [] #这应该是空列表,因为没有超过阈值的值...

【问题讨论】:

  • 您将使用该输出做什么?您将无法将其保留为ndarray,因为它将不再是矩形。所以它不会是真正的行列形式,而是更多的链表。

标签: python numpy threshold numpy-ndarray


【解决方案1】:

这在性能方面可能不是最佳的,但如果你不关心它应该没问题。

indices_items = []
for l in items:
    indices_items.append(np.argwhere(l >= 0.7).flatten().tolist())

indices_items
Out[5]: 
[[0],
[1, 5, 6, 7],
[2, 6, 7, 9],
[3],
[],
[1, 5, 8],
[1, 2, 6, 7, 9],
[1, 2, 6, 7, 9],
[5, 8],
[2, 6, 7, 9]]

【讨论】:

    【解决方案2】:

    使用np.where 获取行col,然后使用np.searchsorted 获取row-array 上的间隔索引并使用它们拆分col-array -

    In [38]: r,c = np.where(items>= 0.7)
    
    In [39]: np.split(c,np.searchsorted(r,range(1,items.shape[0])))
    Out[39]: 
    [array([0], dtype=int64),
     array([1, 5, 6, 7], dtype=int64),
     array([2, 6, 7, 9], dtype=int64),
     array([3], dtype=int64),
     array([], dtype=int64),
     array([1, 5, 8], dtype=int64),
     array([1, 2, 6, 7, 9], dtype=int64),
     array([1, 2, 6, 7, 9], dtype=int64),
     array([5, 8], dtype=int64),
     array([2, 6, 7, 9], dtype=int64)]
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2018-10-07
      • 2011-11-01
      • 1970-01-01
      • 2017-05-21
      • 2019-07-05
      • 2021-03-03
      • 1970-01-01
      相关资源
      最近更新 更多