【问题标题】:For a given condition, get indices of values in 2D tensor A, use those to index a 3D tensor B对于给定条件,获取 2D 张量 A 中的值的索引,使用这些索引来索引 3D 张量 B
【发布时间】:2020-01-27 21:37:07
【问题描述】:

对于给定的二维张量,我想检索值为1 的所有索引。我希望能够简单地使用torch.nonzero(a == 1).squeeze(),这将返回tensor([1, 3, 2])。然而,相反,torch.nonzero(a == 1) 返回一个二维张量(没关系),每行有两个值(这不是我所期望的)。然后应使用返回的索引来索引 3D 张量的第二维(索引 1),再次返回 2D 张量。

import torch

a = torch.Tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])

b = torch.rand(3, 4, 8)

print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])

idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])

print(b.gather(1, idxs))

显然,这不起作用,导致一个RunTimeError:

RuntimeError: invalid argument 4: 索引张量必须相同 尺寸作为输入张量 C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453

看来idxs不是我想的那样,也不能像我想的那样使用。 idxs

tensor([[0, 1],
        [1, 3],
        [2, 2]])

但是通读documentation 我不明白为什么我还要取回结果张量中的行索引。现在,我知道我可以通过切片idxs[:, 1] 来获得正确的 idx,但是我仍然不能将这些值用作 3D 张量的索引,因为会引发与以前相同的错误。是否可以使用索引的一维张量来选择给定维度的项目?

【问题讨论】:

    标签: python multidimensional-array pytorch tensor tensor-indexing


    【解决方案1】:
    import torch
    
    a = torch.Tensor([[12, 1, 0, 0],
                      [4, 9, 21, 1],
                      [10, 2, 1, 0]])
    
    b = torch.rand(3, 4, 8)
    
    print('a_size', a.size())
    # a_size torch.Size([3, 4])
    print('b_size', b.size())
    # b_size torch.Size([3, 4, 8])
    
    #idxs = torch.nonzero(a == 1, as_tuple=True)
    idxs = torch.nonzero(a == 1)
    #print('idxs_size', idxs.size())
    
    print(torch.index_select(b,1,idxs[:,1]))
    

    【讨论】:

    • 不,一点也不。首先z 是一个元组,其次我的最终目标是从b 获取索引。
    • 是的,对不起,我完全没有正确阅读您的问题......对于您的最后一个问题“是否可以使用索引的一维张量来选择给定维度的项目”,确实火炬。 index_select(b,1,idxs[:,1]) 给你你需要的?
    • 没有。这将返回一个 3D 矩阵。我期待一个2D的。看我的回答。
    • 啊,明白了。很高兴你想出来了。旁注只需确保连续没有 2 个 1
    【解决方案2】:

    假设b的三个维度为batch_size x sequence_length x features(b x s x feats),可以达到预期的结果如下。

    import torch
    
    a = torch.Tensor([[12, 1, 0, 0],
                      [4, 9, 21, 1],
                      [10, 2, 1, 0]])
    
    b = torch.rand(3, 4, 8)
    print(b.size())
    # b x s x feats
    idxs = torch.nonzero(a == 1)[:, 1]
    print(idxs.size())
    # b
    c = b[torch.arange(b.size(0)), idxs]
    print(c.size())
    # b x feats
    

    【讨论】:

      【解决方案3】:

      您可以简单地将它们切片并将其作为索引传递,如下所示:

      In [193]: idxs = torch.nonzero(a == 1)     
      In [194]: c = b[idxs[:, 0], idxs[:, 1]]  
      
      In [195]: c   
      Out[195]: 
      tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
              [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
              [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
      

      或者,一个更简单且我更喜欢的方法是只使用torch.where(),然后直接索引到张量b,如下所示:

      In [196]: b[torch.where(a == 1)]  
      Out[196]: 
      tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
              [0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
              [0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
      

      关于上述使用torch.where()的方法的更多解释:它基于advanced indexing的概念工作。也就是说,当我们使用序列对象的元组(例如张量元组、列表元组、元组元组等)对张量进行索引时。

      # some input tensor
      In [207]: a  
      Out[207]: 
      tensor([[12.,  1.,  0.,  0.],
              [ 4.,  9., 21.,  1.],
              [10.,  2.,  1.,  0.]])
      

      对于基本切片,我们需要一个整数索引元组:

         In [212]: a[(1, 2)] 
         Out[212]: tensor(21.)
      

      为了使用高级索引实现同样的目的,我们需要一个序列对象的元组:

      # adv. indexing using a tuple of lists
      In [213]: a[([1,], [2,])] 
      Out[213]: tensor([21.])
      
      # adv. indexing using a tuple of tuples
      In [215]: a[((1,), (2,))]  
      Out[215]: tensor([21.])
      
      # adv. indexing using a tuple of tensors
      In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))] 
      Out[214]: tensor([21.])
      

      并且返回张量的维度总是比输入张量的维度小一维。

      【讨论】:

      • 我非常喜欢您第二个建议的简单性。你能解释一下为什么这有效吗?因为torch.where(a == 1) 返回一个元组。用元组对张量进行切片是如何工作的?
      • @BramVanroy 添加了一些解释 :)
      【解决方案4】:

      作为@kmario23的解决方案的补充,你仍然可以达到同样的结果

      b[torch.nonzero(a==1,as_tuple=True)]
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 2021-01-05
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2022-10-23
        • 1970-01-01
        • 2018-03-13
        • 2017-05-04
        相关资源
        最近更新 更多