【问题标题】:How does PyTorch Tensor.index_select() evaluates tensor output?PyTorch Tensor.index_select() 如何评估张量输出?
【发布时间】:2021-04-10 11:51:05
【问题描述】:

我无法理解索引的复杂性——张量的非连续索引是如何工作的。这是一个示例代码及其输出

import torch

def describe(x):
  print("Type: {}".format(x.type()))
  print("Shape/size: {}".format(x.shape))
  print("Values: \n{}".format(x))


indices = torch.LongTensor([0,2])
x = torch.arange(6).view(2,3)
describe(torch.index_select(x, dim=1, index=indices))

返回输出为

类型:torch.LongTensor 形状/大小:torch.Size([2, 2]) 值: 张量([[0, 2], [3, 5]])

有人能解释一下它是如何到达这个输出张量的吗? 谢谢!

【问题讨论】:

    标签: python indexing pytorch tensor


    【解决方案1】:

    您正在从第一个轴 (dim=0) 上的 x 中选择第一个 (indices[0]0) 和第三个 (indices[1]2) 张量。从本质上讲,torch.index_selectdim=1 的工作方式与使用 x[:, indices] 在第二个轴上进行直接索引相同。

    >>> x
    tensor([[0, 1, 2],
            [3, 4, 5]])
    

    因此选择列(因为您查看的是dim=1 而不是dim=0)哪些索引在indices 中。想象一下有一个简单的列表 [0, 2] 作为indices

    >>> indices = [0, 2]
    
    >>> x[:, indices[0]] # same as x[:, 0]
    tensor([0, 3])
    
    >>> x[:, indices[1]] # same as x[:, 2]
    tensor([2, 5])
    

    因此,将索引作为torch.Tensor 传递允许您直接索引索引的所有元素,即列02。类似于 NumPy 的索引工作原理。

    >>> x[:, indices]
    tensor([[0, 2],
            [3, 5]])
    

    这是另一个示例,可帮助您了解其工作原理。将x 定义为x = torch.arange(9).view(3, 3),所以我们有3 行(又名dim=0)和3 列(又名dim=1)。

    >>> indices
    tensor([0, 2]) # namely 'first' and 'third'
    
    >>> x = torch.arange(9).view(3, 3)
    tensor([[0, 1, 2],
            [3, 4, 5],
            [6, 7, 8]])
    
    >>> x.index_select(0, indices) # select first and third rows
    tensor([[0, 1, 2],
            [6, 7, 8]])
    
    >>> x.index_select(1, indices) # select first and third columns
    tensor([[0, 2],
            [3, 5],
            [6, 8]])
    

    注意torch.index_select(x, dim, indices) 等价于x.index_select(dim, indices)

    【讨论】:

      猜你喜欢
      • 2022-10-18
      • 2018-05-10
      • 2019-08-02
      • 2019-01-26
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2021-12-07
      • 1970-01-01
      相关资源
      最近更新 更多