【问题标题】:Mask from max values in numpy array, specific axis从 numpy 数组中的最大值掩码,特定轴
【发布时间】:2017-12-06 15:47:03
【问题描述】:

输入示例:

我有一个 numpy 数组,例如

a=np.array([[0,1], [2, 1], [4, 8]])

期望的输出:

我想生成一个掩码数组,其最大值沿给定轴(在我的情况下为轴 1)为 True,其他所有轴为 False。例如在这种情况下

mask = np.array([[False, True], [True, False], [False, True]])

尝试:

我尝试过使用np.amax 的方法,但这会返回扁平列表中的最大值:

>>> np.amax(a, axis=1)
array([1, 2, 8])

np.argmax 类似地返回沿该轴的最大值的索引。

>>> np.argmax(a, axis=1)
array([1, 0, 1])

我可以以某种方式对其进行迭代,但是一旦这些数组变得更大,我希望解决方案在 numpy.

【问题讨论】:

    标签: python numpy


    【解决方案1】:

    方法#1

    使用broadcasting,我们可以使用与最大值进行比较,同时保持暗淡以方便broadcasting -

    a.max(axis=1,keepdims=1) == a
    

    示例运行 -

    In [83]: a
    Out[83]: 
    array([[0, 1],
           [2, 1],
           [4, 8]])
    
    In [84]: a.max(axis=1,keepdims=1) == a
    Out[84]: 
    array([[False,  True],
           [ True, False],
           [False,  True]], dtype=bool)
    

    方法#2

    或者使用argmax 索引,以针对沿列的索引范围的broadcasted-comparison 的另一种情况-

    In [92]: a.argmax(axis=1)[:,None] == range(a.shape[1])
    Out[92]: 
    array([[False,  True],
           [ True, False],
           [False,  True]], dtype=bool)
    

    方法#3

    要完成设置,如果我们正在寻找性能,请使用初始化,然后使用 advanced-indexing -

    out = np.zeros(a.shape, dtype=bool)
    out[np.arange(len(a)), a.argmax(axis=1)] = 1
    

    【讨论】:

    • 我喜欢第二个,(np.equal.outer(a.argmax(axis=1), range(a.shape[1])))
    • 我最终使用了方法 #2 的变体。当 axis=1 数组包含多个对应于最大值的项目时,它也有效。 argmax 仅将第一次出现返回为 True,这是最适合我的问题的行为
    【解决方案2】:

    创建一个单位矩阵并在您的数组上使用 argmax 从其行中进行选择:

    np.identity(a.shape[1], bool)[a.argmax(axis=1)]
    # array([[False,  True],
    #        [ True, False],
    #        [False,  True]], dtype=bool)
    

    请注意,这忽略了关系,它只与argmax 返回的值一致。

    【讨论】:

    • np.eye 是一个较短的替代方案,但需要dtype=
    • @Divakar (a.max(1)==a.T).T 在这方面看起来很难被击败。
    【解决方案3】:

    你已经回答了一半。一旦沿轴计算 max,您可以将其与输入数组进行比较,您将获得所需的二进制掩码!

    In [7]: maxx = np.amax(a, axis=1)
    
    In [8]: maxx
    Out[8]: array([1, 2, 8])
    
    In [12]: a >= maxx[:, None]
    Out[12]: 
    array([[False,  True],
           [ True, False],
           [False,  True]], dtype=bool)
    

    注意:在比较amaxx时使用NumPy broadcasting

    【讨论】:

      【解决方案4】:

      在线:np.equal(a.max(1)[:,None],a)np.equal(a.max(1),a.T).T

      但这可能会导致连续出现多个。

      【讨论】:

        【解决方案5】:

        在多维情况下,您还可以使用np.indices。假设你有一个数组:

        a = np.array([[
            [0, 1, 2],
            [3, 8, 5],
            [6, 7, -1],
            [9, 5, 8]],[
            [5, 2, 8],
            [7, 6, -3],
            [-1, 2, 1],
            [3, 5, 6]]
        ])
        

        您可以像这样访问为轴 0 计算的 argmax 值:

        k = np.zeros((2, 4, 3), np.bool)
        k[a.argmax(0), ind[0], ind[1]] = 1
        

        输出将是:

        array([[[False, False, False],
                [False,  True,  True],
                [ True,  True, False],
                [ True,  True,  True]],
        
               [[ True,  True,  True],
                [ True, False, False],
                [False, False,  True],
                [False, False, False]]])
        

        【讨论】:

          猜你喜欢
          • 2019-06-08
          • 1970-01-01
          • 2022-01-02
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 1970-01-01
          • 2019-06-25
          • 1970-01-01
          相关资源
          最近更新 更多