【问题标题】:Indexing using pyTorch tensors along one specific dimension with 3 dimensional tensor使用 pyTorch 张量沿一个特定维度和 3 维张量进行索引
【发布时间】:2021-06-10 17:47:46
【问题描述】:

我有 2 个张量:

带有形状(批次、序列、词汇)的A 和 B 的形状(批次、序列)。

A = torch.tensor([[[ 1.,  2.,  3.],
     [ 5.,  6.,  7.]],

    [[ 9., 10., 11.],
     [13., 14., 15.]]])

B = torch.tensor([[0, 2],
    [1, 0]])

我想得到以下内容:

C = torch.zeros_like(B)
for i in range(B.shape[0]):
   for j in range(B.shape[1]):
      C[i,j] = A[i,j,B[i,j]]

但是以矢量化的方式。我尝试了 torch.gather 和其他东西,但我无法让它工作。 谁能帮帮我?

【问题讨论】:

    标签: python-3.x numpy indexing pytorch torch


    【解决方案1】:
    >>> import torch
    >>> A = torch.tensor([[[ 1.,  2.,  3.],
    ...      [ 5.,  6.,  7.]],
    ... 
    ...     [[ 9., 10., 11.],
    ...      [13., 14., 15.]]])
    >>> B = torch.tensor([[0, 2],
    ...     [1, 0]])
    >>> A.shape
    torch.Size([2, 2, 3])
    >>> B.shape
    torch.Size([2, 2])
    >>> C = torch.zeros_like(B)
    >>> for i in range(B.shape[0]):
    ...    for j in range(B.shape[1]):
    ...       C[i,j] = A[i,j,B[i,j]]
    ... 
    >>> C
    tensor([[ 1,  7],
            [10, 13]])
    >>> torch.gather(A, -1, B.unsqueeze(-1))
    tensor([[[ 1.],
             [ 7.]],
    
            [[10.],
             [13.]]])
    >>> torch.gather(A, -1, B.unsqueeze(-1)).shape
    torch.Size([2, 2, 1])
    >>> torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)
    tensor([[ 1.,  7.],
            [10., 13.]])
    

    您好,您可以使用torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)AB.unsqueeze(-1) 之间的第一个 -1 表示您要沿其选取元素的维度。

    B.unsqueeze(-1) 中的第二个 -1 是向 B 添加一个暗度,以使两个张量具有相同的暗度,否则您将得到 RuntimeError: Index tensor must have the same number of dimensions as input tensor

    最后一个-1 是将结果从torch.Size([2, 2, 1]) 重塑为torch.Size([2, 2])

    【讨论】:

    • 非常感谢。这非常有帮助:)
    猜你喜欢
    • 2019-09-15
    • 2021-07-17
    • 2020-09-26
    • 1970-01-01
    • 2019-02-05
    • 2019-02-19
    • 2021-02-10
    • 2021-12-01
    • 2020-10-26
    相关资源
    最近更新 更多