【问题标题】:Indexing second dimension of Tensor using indices使用索引索引张量的第二维
【发布时间】:2017-12-20 15:07:16
【问题描述】:

我使用索引张量在我的张量中选择了元素。下面的代码我使用索引列表 0、3、2、1 来选择 11、15、2、5

>>> import torch
>>> a = torch.Tensor([5,2,11, 15])
>>> torch.randperm(4)

 0
 3
 2
 1
[torch.LongTensor of size 4]

>>> i = torch.randperm(4)
>>> a[i]

 11
 15
  2
  5
[torch.FloatTensor of size 4]

现在,我有

>>> b = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
>>> b

  5   2  11  15
  5   2  11  15
  5   2  11  15
[torch.FloatTensor of size 3x4]

现在,我想使用索引来选择第 0、3、2、1 列。换句话说,我想要一个这样的张量

>>> b

 11  15   2   5
 11  15   2   5
 11  15   2   5
[torch.FloatTensor of size 3x4]

【问题讨论】:

    标签: python indexing pytorch tensor


    【解决方案1】:

    如果使用 pytorch 版本 v0.1.12

    对于这个版本,没有简单的方法可以做到这一点。尽管 pytorch 承诺张量操作与 numpy 完全一样,但仍然缺少一些功能。这是其中之一。

    如果您使用 numpy 数组,通常可以相对轻松地做到这一点。像这样。

    >>> i = [2, 1, 0, 3]
    >>> a = np.array([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
    >>> a[:, i]
    
    array([[11,  2,  5, 15],
           [11,  2,  5, 15],
           [11,  2,  5, 15]])
    

    但是对于张量来说同样的事情会给你一个错误:

    >>> i = torch.LongTensor([2, 1, 0, 3])
    >>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
    >>> a[:,i]
    

    错误:

    TypeError:使用类型为 torch.LongTensor 的对象对张量进行索引。唯一受支持的类型是整数、切片、numpy 标量和 torch.LongTensor 或 torch.ByteTensor 作为唯一参数。

    TypeError 告诉您的是,如果您打算使用 LongTensor 或 ByteTensor 进行索引,那么唯一有效的语法是 a[<LongTensor>]a[<ByteTensor>]。除此之外的任何东西都行不通。

    由于此限制,您有两种选择:

    选项 1: 转换为 numpy,置换,然后返回到 Tensor

    >>> i = [2, 1, 0, 3]
    >>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
    >>> np_a = a.numpy()
    >>> np_a = np_a[:,i]
    >>> a = torch.from_numpy(np_a)
    >>> a
    
     11   2   5  15
     11   2   5  15
     11   2   5  15
    [torch.FloatTensor of size 3x4]
    

    选项 2:将要置换的暗淡移动到 0,然后执行此操作

    您将要置换的暗淡(在您的情况下为 dim=1)移动到 0,执行置换,然后将其移回。它有点 hacky,但它完成了工作。

    def hacky_permute(a, i, dim):
        a = torch.transpose(a, 0, dim)
        a = a[i]
        a = torch.transpose(a, 0, dim)
        return a
    

    然后像这样使用它:

    >>> i = torch.LongTensor([2, 1, 0, 3])
    >>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
    >>> a = hacky_permute(a, i, dim=1)
    >>> a
    
     11   2   5  15
     11   2   5  15
     11   2   5  15
    [torch.FloatTensor of size 3x4]
    

    如果使用 pytorch v0.2.0 版本

    使用张量的直接索引现在可以在此版本中使用。即。

    >>> i = torch.LongTensor([2, 1, 0, 3])
    >>> a = torch.Tensor([[5, 2, 11, 15],[5, 2, 11, 15], [5, 2, 11, 15]])
    >>> a[:,i]
    
     11   2   5  15
     11   2   5  15
     11   2   5  15
    [torch.FloatTensor of size 3x4]
    

    【讨论】:

    • 谢谢。 torch.transpose 是我当前使用的解决方法。
    猜你喜欢
    • 2019-09-15
    • 2020-09-26
    • 1970-01-01
    • 1970-01-01
    • 2021-06-27
    • 2018-06-08
    • 2017-05-04
    • 2021-11-25
    • 2017-08-25
    相关资源
    最近更新 更多