【问题标题】:Find indices of elements equal to zero in a PyTorch Tensor在 PyTorch 张量中查找等于零的元素索引
【发布时间】:2020-07-04 03:13:59
【问题描述】:

我的问题与this 的问题几乎相同,但在 PyTorch 中存在显着差异。我不希望使用 Numpy 解决方案,因为这将涉及将数据移回 CPU。我看到,与 Numpy 一样,PyTorch 有一个 nonzero 函数,但是它的 where 函数(我链接的 Numpy 线程中的解决方案)具有与 Numpy 不同的行为。 我想要的行为是is_zero() 函数,如下所示:

>>> arr.nonzero()
tensor([[0, 1],
        [1, 0]])  
>>> arr.is_zero()
tensor([[0, 0],
        [1, 1]])

【问题讨论】:

    标签: python pytorch


    【解决方案1】:

    你可以做一个布尔掩码然后调用nonzero():

    (arr == 0).nonzero()
    

    例如:

    arr = torch.randint(high=2, size=(3, 3))
    
    tensor([[1, 1, 0],  # (0, 2)
            [1, 1, 0],  # (1, 2)
            [1, 0, 0]]) # (2, 1) and (2, 2)
    
    (arr == 0).nonzero()
    
    tensor([[0, 2],
            [1, 2],
            [2, 1],
            [2, 2]])
    

    【讨论】:

    • 谢谢!这正是我所需要的。
    猜你喜欢
    • 1970-01-01
    • 2021-09-14
    • 2011-06-03
    • 1970-01-01
    • 2019-06-01
    • 2019-09-28
    • 2022-11-24
    • 2021-06-27
    相关资源
    最近更新 更多