【问题标题】:Torch gather middle dimension火炬聚中维度
【发布时间】:2021-03-30 09:39:05
【问题描述】:

a 成为(n, d, l) 张量。让 indices 成为 (n, 1) 张量,包含索引。我想从indices 给出的索引中收集中间维度张量中的a。因此,生成的张量的形状将是 (n, l)

n = 3
d = 2
l = 3

a = tensor([[[ 0,  1,  2],
             [ 3,  4,  5]],

            [[ 6,  7,  8],
             [ 9, 10, 11]],

            [[12, 13, 14],
             [15, 16, 17]]])

indices = tensor([[0],
                  [1],
                  [0]])

# Shape of result is (n, l)
result = tensor([[ 0,  1,  2],  # a[0, 0, :] since indices[0] == 0

                 [ 9, 10, 11],  # a[1, 1, :] since indices[1] == 1

                 [12, 13, 14]]) # a[2, 0, :] since indices[2] == 0

这确实类似于a.gather(1, indices),但gather 不起作用,因为indicesa 的形状不同。如何在此设置中使用gather?或者我应该使用什么?

【问题讨论】:

    标签: python pytorch torch


    【解决方案1】:

    您可以手动创建索引。如果 indices 张量具有示例数据的形状,则必须将其展平。

    a[torch.arange(len(a)),indices.view(-1)]
    # equal to a[[0,1,2],[0,1,0]]
    

    输出:

    tensor([[ 0,  1,  2],
            [ 9, 10, 11],
            [12, 13, 14]])
    

    【讨论】:

      【解决方案2】:

      我将我的答案添加到迈克尔的顶部,以获取索引维度两侧的更多维度,但我希望有人给我一个更好的不使用 arange 的答案!

      def squeeze_index(x, dim, index):
        # flatten to rows
        y = x.view((-1,) + x.shape[dim:])
      
        # generate row indices
        rows = torch.arange(y.shape[0])
      
        # index and reshape
        result_shape = x.shape[:dim] + (x.shape[dim+1:] if dim != -1 else ())
        return y[rows, index.view(-1), ...].view(result_shape)
      
      a = torch.arange(2*3*2*3).reshape((2,3,2,3))
      indices = torch.tensor([0,0,1,0,0,1]).reshape((2,3))
      result = squeeze_index(a, 2, i)
      print("a", a.shape, a)
      print("indices", indices.shape, indices)
      print("result", result.shape, result)
      

      给予:

      a torch.Size([2, 3, 2, 3]) tensor([[[[ 0,  1,  2],
                [ 3,  4,  5]],
      
               [[ 6,  7,  8],
                [ 9, 10, 11]],
      
               [[12, 13, 14],
                [15, 16, 17]]],
      
      
              [[[18, 19, 20],
                [21, 22, 23]],
      
               [[24, 25, 26],
                [27, 28, 29]],
      
               [[30, 31, 32],
                [33, 34, 35]]]])
      indices torch.Size([2, 3]) tensor([[0, 0, 1],
              [0, 0, 1]])
      result torch.Size([2, 3, 3]) tensor([[[ 0,  1,  2],
               [ 6,  7,  8],
               [15, 16, 17]],
      
              [[18, 19, 20],
               [24, 25, 26],
               [33, 34, 35]]])
      

      【讨论】:

        【解决方案3】:

        在使用收集函数之前,重塑索引,这是一个例子

        def gather_righthand(src, index, check=True):
            index = index.long()
            i_dim = index.dim(); s_dim = src.dim(); t_dim = i_dim-1
            if check:
                assert s_dim > i_dim
                for d in range(0, t_dim): 
                    assert src.shape[d] == index.shape[d]
            index_new_shape = list(src.shape)
            index_new_shape[t_dim] = index.shape[t_dim]
            for _ in range(i_dim, s_dim): index = index.unsqueeze(-1)
        
            index_expand = index.expand(index_new_shape)            # only this two line matters
            return torch.gather(src, dim=t_dim, index=index_expand) # only this two line matters
        
        
        gather_righthand(a, indices)
        tensor([[[ 0.,  1.,  2.]],
                [[ 9., 10., 11.]],
                [[12., 13., 14.]]])
        

        【讨论】:

          猜你喜欢
          • 2021-09-27
          • 2022-07-20
          • 2018-01-24
          • 2016-12-14
          • 1970-01-01
          • 2018-09-02
          • 2020-11-28
          • 2017-02-23
          • 2017-04-09
          相关资源
          最近更新 更多