【问题标题】:Pytorch tensor get the index of the element with specific values?Pytorch张量获取具有特定值的元素的索引?
【发布时间】:2021-01-25 18:17:21
【问题描述】:

我有两个张量,张量 a 和张量 b。

我想获取张量 b 中值的所有索引。

例如。

a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])

我想要张量 a 中 1, 2, 4 的索引。我可以通过以下代码做到这一点。

a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
mask = torch.zeros(a.shape).type(torch.bool)
print(mask)
for e in b:
    mask = mask + (a == e)
    print(mask)

没有for怎么办?

【问题讨论】:

    标签: python numpy tensorflow pytorch


    【解决方案1】:

    如果你不想使用 for 循环,你可以使用列表推导:

    mask = [a[index] for index in b]
    
    

    如果甚至不想使用“for”字,您可以随时将张量转换为 numpy 并使用 numpy 索引。

    mask = torch.tensor(a.numpy()[b.numpy()])
    

    更新

    可能误解了您的问题。在这种情况下,我想说实现这一点的最佳方法是通过列表理解。 (切片可能无法做到这一点。

    mask = [index for index,value in enumerate(a) if value in b.tolist()] 
    

    这会遍历 a 中的每个元素,获取它们的索引和值,如果值在 b 中,则获取索引。

    【讨论】:

    • 我不想在 b 中获取索引值。我想要做的是获取值在 b 中的元素的索引。
    • 更新了答案。我一定误解了你的问题。
    【解决方案2】:

    更新:

    正如@zaydh 在 cmets 中指出的那样,由于 PyTorch 1.10isin()isinf()(以及许多其他 numpy 等效项)也可用,因此您可以简单地这样做:

    torch.isin(a, b)
    

    这会给你:

    Out[4]: tensor([ True,  True,  True, False,  True,  True,  True, False])
    

    旧答案:

    这是你想要的吗? :

    np.in1d(a.numpy(), b.numpy())
    

    将导致:

    array([ True,  True,  True, False,  True,  True,  True, False])
    

    【讨论】:

    • 好的,非常感谢。我可以在 PyTorch 中做到吗(不要将张量更改为 NumPy 数组)?
    • Pytorch 没有实现in1d,在 pytorch 中使用应该不会有任何问题。是什么在烦你? numpy 版本也使用底层数据,所以如果您担心复制等,则没有!
    • 顺便说一句,从torch 1.10 开始,isin 具有本机支持,因此如果您使用该版本或更高版本,则无需转换为numpy
    • @ZaydH 感谢您的更新,将新信息添加到答案中。
    猜你喜欢
    • 2021-05-25
    • 2019-01-13
    • 2021-09-14
    • 2020-05-11
    • 1970-01-01
    • 1970-01-01
    • 2019-05-11
    • 2022-11-24
    • 2018-05-31
    相关资源
    最近更新 更多