【问题标题】:Getting the count of items under diagonal in numpy在numpy中获取对角线下的项目数
【发布时间】:2019-09-20 07:14:41
【问题描述】:

我有一个相关矩阵,我想计算对角线以下的项目数。最好在 numpy 中。

[[1,   0,   0,   0,  0], 
[.35,  1,   0,   0,  0], 
[.42, .31,  1,   0,  0], 
[.25, .38, .41,  1,  0], 
[.21, .36, .46, .31, 1]]

我希望它返回 10。或者,返回对角线下所有数字的平均值。

【问题讨论】:

  • 您能否检查一下 numpy.trace 并告诉我是否有帮助。

标签: numpy mean diagonal


【解决方案1】:

设置

a = np.array([[1.  , 0.  , 0.  , 0.  , 0.  ],
              [0.35, 1.  , 0.  , 0.  , 0.  ],
              [0.42, 0.31, 1.  , 0.  , 0.  ],
              [0.25, 0.38, 0.41, 1.  , 0.  ],
              [0.21, 0.36, 0.46, 0.31, 1.  ]])

numpy.tril_indices 将给出对角线下所有元素的索引(如果您提供-1 的偏移量),从那里开始,它变得像索引和调用meansize 一样简单


n, m = a.shape

m = np.tril_indices(n=n, k=-1, m=m)

a[m]
# array([0.35, 0.42, 0.31, 0.25, 0.38, 0.41, 0.21, 0.36, 0.46, 0.31])

a[m].mean()
# 0.346

a[m].size
# 10

【讨论】:

    【解决方案2】:

    一个更原始和庞大的答案,因为numpy 提供np.tril_indices,正如 user3483203 提到的那样,但是你想要的每行迭代 i 如下(就 [row,col] 索引而言):

                          (i=0)
    [1,0]                 (i=1) 
    [2,0] [2,1]           (i=2)
    [3,0] [3,1] [3,2]     (i=3)
    ...
    

    这本质上是列表[i,i,i,...] = [i]*izip(i 的重复)和[0,1,...,i-1] = range(i)。因此,遍历表的行,您实际上可以获得每次迭代的索引并执行您选择的运算符。

    示例设置:

    test = np.array(
    [[1,   0,   0,   0,  0], 
    [.35,  1,   0,   0,  0], 
    [.42, .31,  1,   0,  0], 
    [.25, .38, .41,  1,  0], 
    [.21, .36, .46, .31, 1]])
    

    函数定义:

    def countdiag(myarray):
        numvals = 0
        totsum = 0
    
    
        for i in range(myarray.shape[0]): # row iteration
            colc = np.array(range(i)) # calculate column indices
            rowc = np.array([i]*i) # calculate row indices
    
            if any(rowc):
                print(np.sum(myarray[rowc,colc]))
                print(len(myarray[rowc,colc]))
    
                numvals += len(myarray[rowc,colc])
                totsum += np.sum(myarray[rowc,colc])
    
    
            print(list(zip([i]*i, np.arange(i))))
    
        mean = totsum / numvals
    
        return mean, numvals
    

    测试:

     [165]: countdiag(test)
    
    
    
    []
    0.35
    1
    [(1, 0)]
    0.73
    2
    [(2, 0), (2, 1)]
    1.04
    3
    [(3, 0), (3, 1), (3, 2)]
    1.34
    4
    [(4, 0), (4, 1), (4, 2), (4, 3)]
    0.346
    Out[165]:
    (0.346, 10)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2019-10-20
      • 2022-01-26
      • 2012-06-05
      • 2018-10-03
      • 2015-05-09
      • 2021-12-21
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多