【发布时间】:2020-01-27 21:37:07
【问题描述】:
对于给定的二维张量,我想检索值为1 的所有索引。我希望能够简单地使用torch.nonzero(a == 1).squeeze(),这将返回tensor([1, 3, 2])。然而,相反,torch.nonzero(a == 1) 返回一个二维张量(没关系),每行有两个值(这不是我所期望的)。然后应使用返回的索引来索引 3D 张量的第二维(索引 1),再次返回 2D 张量。
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])
idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])
print(b.gather(1, idxs))
显然,这不起作用,导致一个RunTimeError:
RuntimeError: invalid argument 4: 索引张量必须相同 尺寸作为输入张量 C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453
看来idxs不是我想的那样,也不能像我想的那样使用。 idxs是
tensor([[0, 1],
[1, 3],
[2, 2]])
但是通读documentation 我不明白为什么我还要取回结果张量中的行索引。现在,我知道我可以通过切片idxs[:, 1] 来获得正确的 idx,但是我仍然不能将这些值用作 3D 张量的索引,因为会引发与以前相同的错误。是否可以使用索引的一维张量来选择给定维度的项目?
【问题讨论】:
标签: python multidimensional-array pytorch tensor tensor-indexing