【问题标题】:Pytorch: How can I find indices of first nonzero element in each row of a 2D tensor?Pytorch:如何在二维张量的每一行中找到第一个非零元素的索引?
【发布时间】:2019-09-28 23:42:52
【问题描述】:

我有一个二维张量,每一行都有一些非零元素,如下所示:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

我想要一个张量,其中包含每行中第一个非零元素的索引:

indices = tensor([2],
                 [3])

如何在 Pytorch 中计算?

【问题讨论】:

    标签: python machine-learning pytorch


    【解决方案1】:

    我简化了 Iman 的方法来执行以下操作:

    idx = torch.arange(tmp.shape[1], 0, -1)
    tmp2= tmp * idx
    indices = torch.argmax(tmp2, 1, keepdim=True)
    

    【讨论】:

    • 这是一个简单而聪明的解决方案!
    • 请注意,如果没有非零,则默认为零作为第一个非零索引。
    【解决方案2】:

    我可以为我的问题找到一个棘手的答案:

      tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                         [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
      idx = reversed(torch.Tensor(range(1,8)))
      print(idx)
    
      tmp2= torch.einsum("ab,b->ab", (tmp, idx))
    
      print(tmp2)
    
      indices = torch.argmax(tmp2, 1, keepdim=True)
      print(indeces)
    

    结果是:

    tensor([7., 6., 5., 4., 3., 2., 1.])
    tensor([[0., 0., 5., 0., 3., 0., 0.],
           [0., 0., 0., 4., 3., 0., 0.]])
    tensor([[2],
            [3]])
    

    【讨论】:

      【解决方案3】:

      所有非零值都相等,所以argmax 返回第一个索引。

      tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                          [0, 0, 0, 1, 1, 0, 0]])
      indices = tmp.argmax(1)
      

      【讨论】:

        猜你喜欢
        • 2019-09-15
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2020-07-04
        • 2020-09-26
        • 1970-01-01
        • 2022-11-14
        • 2022-10-07
        相关资源
        最近更新 更多