【Question title】: PyTorch indexing: select the complement of indices
【Posted】: 2021-07-13 09:47:42
【Question description】:

Suppose I have a tensor and a tensor of indices:

x = torch.tensor([1,2,3,4,5])
idx = torch.tensor([0,2,4])

If I want to select all elements whose positions are *not* in the indices, I can manually define a boolean mask like this:

mask = torch.ones_like(x, dtype=torch.bool)
mask[idx] = False

x[mask]

Is there a more elegant way to do this?

That is, is there syntax that lets me pass the indices directly instead of building a mask, e.g. something like:

x[~idx]
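For reference, a minimal end-to-end sketch of the boolean-mask approach above; note the mask must have `dtype=torch.bool`, since indexing with an integer tensor gathers by position instead of masking:

```python
import torch

x = torch.tensor([1, 2, 3, 4, 5])
idx = torch.tensor([0, 2, 4])

# A boolean mask: True everywhere, then cleared at the given indices.
mask = torch.ones_like(x, dtype=torch.bool)
mask[idx] = False

print(x[mask])  # tensor([2, 4])
```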

【Question comments】:

    Tags: python numpy pytorch matrix-indexing


    【Solution 1】:

    You could try a one-line expression:

    import numpy as np
    x[np.setdiff1d(range(len(x)), idx)]
    

    Though that doesn't look very elegant either :).
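    For completeness, the one-liner in runnable form with the question's data; note `np.setdiff1d` also sorts the surviving indices:

```python
import numpy as np
import torch

x = torch.tensor([1, 2, 3, 4, 5])
idx = torch.tensor([0, 2, 4])

# np.setdiff1d returns the sorted positions of range(len(x)) that do
# not appear in idx; a NumPy integer array is a valid PyTorch index.
keep = np.setdiff1d(range(len(x)), idx)
result = x[keep]
print(result)  # tensor([2, 4])
```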

    【Discussion】:

      【Solution 2】:

      I couldn't find a satisfactory solution for computing the complement of a multi-dimensional index tensor, so I ended up implementing my own. It works on CUDA and benefits from fast parallel computation.

      def complement_idx(idx, dim):
          """
          Compute the complement: set(range(dim)) - set(idx).
          idx is a multi-dimensional tensor; the complement is taken along its
          trailing dimension, and all other dimensions are treated as batch
          dimensions.
          Args:
              idx: input index tensor, shape: [N, *, K]
              dim: size of the index range (indices lie in [0, dim))
          Returns:
              complement index tensor, shape: [N, *, dim - K]
          """
          a = torch.arange(dim, device=idx.device)
          ndim = idx.ndim
          dims = idx.shape
          n_idx = dims[-1]
          dims = dims[:-1] + (-1, )
          # Broadcast arange(dim) to the batch shape of idx.
          for i in range(1, ndim):
              a = a.unsqueeze(0)
          a = a.expand(*dims)
          # Overwrite the indexed positions with 0, then sort: the first
          # n_idx entries along the last dimension are exactly the values
          # to discard (0 itself survives iff it is not in idx).
          masked = torch.scatter(a, -1, idx, 0)
          compl, _ = torch.sort(masked, dim=-1, descending=False)
          # Move the last dimension to the front, drop the first n_idx
          # entries, and move it back.
          compl = compl.permute(-1, *tuple(range(ndim - 1)))
          compl = compl[n_idx:].permute(*(tuple(range(1, ndim)) + (0,)))
          return compl
      

      Example:

      >>> import torch
      >>> a = torch.rand(3, 4, 5)
      >>> a
      tensor([[[0.7849, 0.7404, 0.4112, 0.9873, 0.2937],
               [0.2113, 0.9923, 0.6895, 0.1360, 0.2952],
               [0.9644, 0.9577, 0.2021, 0.6050, 0.7143],
               [0.0239, 0.7297, 0.3731, 0.8403, 0.5984]],
      
              [[0.9089, 0.0945, 0.9573, 0.9475, 0.6485],
               [0.7132, 0.4858, 0.0155, 0.3899, 0.8407],
               [0.2327, 0.8023, 0.6278, 0.0653, 0.2215],
               [0.9597, 0.5524, 0.2327, 0.1864, 0.1028]],
      
              [[0.2334, 0.9821, 0.4420, 0.1389, 0.2663],
               [0.6905, 0.2956, 0.8669, 0.6926, 0.9757],
               [0.8897, 0.4707, 0.5909, 0.6522, 0.9137],
               [0.6240, 0.1081, 0.6404, 0.1050, 0.6413]]])
      >>> b, c = torch.topk(a, 2, dim=-1)
      >>> b
      tensor([[[0.9873, 0.7849],
               [0.9923, 0.6895],
               [0.9644, 0.9577],
               [0.8403, 0.7297]],
      
              [[0.9573, 0.9475],
               [0.8407, 0.7132],
               [0.8023, 0.6278],
               [0.9597, 0.5524]],
      
              [[0.9821, 0.4420],
               [0.9757, 0.8669],
               [0.9137, 0.8897],
               [0.6413, 0.6404]]])
      >>> c
      tensor([[[3, 0],
               [1, 2],
               [0, 1],
               [3, 1]],
      
              [[2, 3],
               [4, 0],
               [1, 2],
               [0, 1]],
      
              [[1, 2],
               [4, 2],
               [4, 0],
               [4, 2]]])
      >>> compl = complement_idx(c, 5)
      >>> compl
      tensor([[[1, 2, 4],
               [0, 3, 4],
               [2, 3, 4],
               [0, 2, 4]],
      
              [[0, 1, 4],
               [1, 2, 3],
               [0, 3, 4],
               [2, 3, 4]],
      
              [[0, 3, 4],
               [0, 1, 3],
               [1, 2, 3],
               [0, 1, 3]]])
      >>> al = torch.cat([c, compl], dim=-1)
      >>> al
      tensor([[[3, 0, 1, 2, 4],
               [1, 2, 0, 3, 4],
               [0, 1, 2, 3, 4],
               [3, 1, 0, 2, 4]],
      
              [[2, 3, 0, 1, 4],
               [4, 0, 1, 2, 3],
               [1, 2, 0, 3, 4],
               [0, 1, 2, 3, 4]],
      
              [[1, 2, 0, 3, 4],
               [4, 2, 0, 1, 3],
               [4, 0, 1, 2, 3],
               [4, 2, 0, 1, 3]]])
      >>> al, _ = al.sort(dim=-1)
      >>> al
      tensor([[[0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4]],
      
              [[0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4]],
      
              [[0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4],
               [0, 1, 2, 3, 4]]])
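      The function also covers the 1-D case from the original question; a quick sketch (repeating the definition above so it runs on its own):

```python
import torch

def complement_idx(idx, dim):
    """Same function as defined in this answer; works for 1-D idx too."""
    a = torch.arange(dim, device=idx.device)
    ndim = idx.ndim
    dims = idx.shape
    n_idx = dims[-1]
    dims = dims[:-1] + (-1,)
    for i in range(1, ndim):
        a = a.unsqueeze(0)
    a = a.expand(*dims)
    masked = torch.scatter(a, -1, idx, 0)
    compl, _ = torch.sort(masked, dim=-1, descending=False)
    compl = compl.permute(-1, *tuple(range(ndim - 1)))
    compl = compl[n_idx:].permute(*(tuple(range(1, ndim)) + (0,)))
    return compl

# 1-D usage: complement of {0, 2, 4} within range(5)
out = complement_idx(torch.tensor([0, 2, 4]), 5)
print(out)  # tensor([1, 3])
```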
      

      【Discussion】:
