【问题标题】:How can I efficiently find the maximum element and its indices in a scipy csr_matrix?如何在 scipy csr_matrix 中有效地找到最大元素及其索引?
【发布时间】:2017-03-31 10:49:00
【问题描述】:

我有一个稀疏的无向图存储在scipy csr_matrix中,我需要找出权重最大的边,这意味着我需要找到最大值及其对应的行和列索引(实际上我必须找到K 个最大值,但为了简化问题)。因此我写道:

M=M.toarray()
for i in range(1,len(M)):
    for j in range(i+1,len(M[i])):
        if M[i][j] > maximum:
            row,col,maximum = i,j,M[i][j]

它似乎很笨拙,表现不佳。有没有更好的方法来做到这一点?

【问题讨论】:

标签: python matrix scipy


【解决方案1】:
import numpy as np
np.array(np.unravel_index(np.argsort(M.flatten(), axis=0)[-K:], M.shape)).T[::-1]

这将按降序返回 K 个最大值的索引数组。 示例:

M=np.ones((10,10))
M[2,6]=10
M[4,3]=12
np.array(np.unravel_index(np.argsort(M.flatten(), axis=0)[-2:], M.shape)).T[::-1]

输出:[[4,3], [2,6]]

编辑:

(假设你有一个方阵)。 如果您不希望返回任何对角线元素,您可以屏蔽数组:

mask = (1-np.identity(M.shape[0]))
np.array(np.unravel_index(np.argsort((M*mask).flatten(), axis=0)[-2:], M.shape)).T[::-1]

【讨论】:

  • 对于稀疏的M,不能使用M.flatten()。但是M.A.flatten() 会起作用。 np.argmax(M.A) 给出最大值的散列索引。在任何情况下,如果您想要密集数组中的行/列,使用unravel_index 是一个很好的工具。
  • 是的,这是正确的@hpaulj - 或 M.toarray() 就像 OP 一样:-)
  • 您的解决方案运行良好,但出现了另一个问题:我不需要像 M[3,3],M[4,4] 这样的对角线上的元素。在我发布的代码中,我使用 for 循环来摆脱它们。如何使用您的解决方案做到这一点?
【解决方案2】:

如果你想单独找到最大值,M.max() 就足够了:

>>> m = scipy.sparse.rand(1000, 1000, format='csr')
>>> type(m)
<class 'scipy.sparse.csr.csr_matrix'>
>>> m.max()
0.99991127228906729

如果你也想找到索引besides converting to a coo_matrix,你可以直接对.data.indices.indptr进行操作。文档中提到了这些成员之间的关系,

csr_matrix((data, indices, indptr), [shape=(M, N)])

是标准 CSR 表示,其中行 i 的列索引存储在 indices[indptr[i]:indptr[i+1]] 中,它们的相应值存储在 data[indptr[i]:indptr[i+1]] 中。

所以,

>>> m.sort_indices()
>>> numpy.argmax(m.data)
1171
>>> index = _
>>> m.indices[index]
483
>>> col = _
>>> numpy.searchsorted(m.indptr, index, side='right') - 1
116
>>> row = _
>>> m[row, col]
0.99991127228906729

【讨论】:

  • 我试过在.indices.indptr 上工作,只是感到困惑。我不喜欢直接操纵班级成员的方式。谢谢。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2022-01-26
  • 2014-10-22
  • 2012-09-29
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多