【问题标题】:numpy get top k elements from last dimension of ndarraynumpy 从 ndarray 的最后一个维度获取前 k 个元素
【发布时间】:2018-08-07 21:59:53
【问题描述】:

我有一个多维数组,我需要从最后一维的每一行中获取前 k 个元素。

>>> x = np.random.random_integers(0, 100, size=(2,1,1,5))
>>> x
array([[[[99, 39, 10, 18, 68]]],
       [[[22,  3, 13, 56,  2]]]])

我想得到:

array([[[[ 99.,  68.]]],
       [[[ 18.,  99.]]]])

我可以使用以下方法获取索引,但我不确定如何切出值。

>>> k = 2
>>> parts = np.flip(-1 - np.arange(k), 0)
>>> indices = np.flip(
...     np.argpartition(x, parts, axis=-1)[..., -k:],
...     axis=-1)
>>> indices
array([[[[0, 4]]],
       [[[3, 0]]]])

【问题讨论】:

    标签: numpy multidimensional-array numpy-slicing


    【解决方案1】:

    这可以解决你的问题。

    np.sort(x, axis=len(x.shape)-1)[...,-2:]
    

    【讨论】:

      【解决方案2】:
      np.partition(x, 2)[..., -2:]
      

      从每行返回 2 个最大的元素。例如,

      x = np.random.random_integers(0, 100, size=(2,1,1,5))
      print(x)
      print(np.partition(x, 2)[..., -2:])
      

      打印类似的东西

      [[[[79 34 90 80 56]]]
       [[[78 11 24 20 42]]]]
      
      [[[[80 90]]]
       [[[78 42]]]]
      

      【讨论】:

        猜你喜欢
        • 2016-06-30
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2016-01-01
        • 2017-03-03
        • 2016-07-23
        相关资源
        最近更新 更多