【问题标题】:get top n values for each row in a matrix with numpy使用numpy获取矩阵中每一行的前n个值
【发布时间】:2021-12-27 01:00:40
【问题描述】:

假设你有以下矩阵

import numpy as np 
a = np.array([[10,50,30],[60,20,40],[15,30,90]])

您可以通过如下方式获取每行中最高值的索引

r,c = np.unravel_index(a.argmax(axis=1)+ np.arange(0,a.shape[1]*a.shape[0],a.shape[1]), a.shape)
print(a[r[0],c[0]])
print(a[r[1],c[1]])
print(a[r[2],c[2]])

如何获取每行顶部 n 值的索引?

【问题讨论】:

  • 先按行轴排序,然后是前n列

标签: python-3.x numpy numpy-ndarray


【解决方案1】:

对行排序并选择最后n列:

In [315]: a = np.array([[10,50,30],[60,20,40],[15,30,90]])
In [316]: a
Out[316]: 
array([[10, 50, 30],
       [60, 20, 40],
       [15, 30, 90]])
In [317]: np.sort(a, axis=1)
Out[317]: 
array([[10, 30, 50],
       [20, 40, 60],
       [15, 30, 90]])
In [319]: np.sort(a, axis=1)[:,-2:]
Out[319]: 
array([[30, 50],
       [40, 60],
       [30, 90]])

您的最大表达式可以简化为:

In [324]: a.argmax(axis=1)+ np.arange(0,a.shape[1]*a.shape[0],a.shape[1])
Out[324]: array([1, 3, 8])
In [326]: a.flat[_]
Out[326]: array([50, 60, 90])

【讨论】:

  • 我的问题不够清楚。我需要找到前 n 个值的索引。我已经更新了我的问题。
  • 你试过argsort同样的方法吗?
  • 谢谢,这正是我需要的!
  • 您能否更新您的答案以便我查看?
【解决方案2】:

您应该可以只使用np.argpartition 并设置axis=1

import numpy as np 
a = np.array([
              [10,50,30],
              [60,20,40],
              [15,30,90]
              ])
n = 2
indices = np.argpartition(a, -n, axis=1)[:, -n:]
print(indices)
print('\n', np.sort(indices))
[[2 1]
 [2 0]
 [1 2]]

 [[1 2]
 [0 2]
 [1 2]]

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2023-03-21
    • 1970-01-01
    • 2022-06-30
    • 1970-01-01
    相关资源
    最近更新 更多