【问题标题】:How Pytorch Tensor get the index of specific valuePytorch Tensor 如何获取特定值的索引
【发布时间】:2018-05-31 11:21:53
【问题描述】:

在python列表中,我们可以使用list.index(somevalue)。 pytorch 如何做到这一点?
例如:

    a=[1,2,3]
    print(a.index(2))

然后,1 将被输出。 pytorch 张量如何在不将其转换为 python 列表的情况下做到这一点?

【问题讨论】:

标签: python pytorch


【解决方案1】:

我认为没有从 list.index() 直接转换为 pytorch 函数。但是,您可以使用tensor==numbernonzero() 函数获得类似的结果。例如:

t = torch.Tensor([1, 2, 3])
print ((t == 2).nonzero(as_tuple=True)[0])

这段代码返回

1

[torch.LongTensor 大小为 1x1]

【讨论】:

  • 如果没有匹配会发生什么?你会怎么处理呢?
  • @CharlieParker 如果没有匹配则返回空张量tensor([], dtype=torch.int64)
  • 我们如何扩展它以获得批量索引?在这种情况下,我想要一次索引值torch.Tensor([1, 2, 3]),而不仅仅是2。有没有没有for循环的方法?
【解决方案2】:

对于多维张量,您可以这样做:

(tensor == target_value).nonzero(as_tuple=True)

生成的张量的形状为number_of_matches x tensor_dimension。例如,假设 tensor 是一个 3 x 4 张量(这意味着维度是 2),结果将是一个二维张量,其中包含行中匹配项的索引。

tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
(tensor == 2).nonzero(as_tuple=False)
>>> tensor([[0, 1],
        [0, 2],
        [1, 2]])

【讨论】:

  • 最完整的答案!一般不仅仅是平坦张量。
  • 如果没有匹配会发生什么?你会怎么处理呢?
  • 在这种情况下,你会得到一个空的张量。张量形状仍将遵循输入张量的尺寸,因此在上面的示例中,搜索8 将导致形状为0 x 2 的(空)张量。
【解决方案3】:

根据其他人的回答:

t = torch.Tensor([1, 2, 3])
print((t==1).nonzero().item())

【讨论】:

  • 如果张量只包含一个预期数字的出现,这没关系。那是因为.item()方法只能在单元素张量上调用,否则会报错。
【解决方案4】:

可以通过如下转换为numpy来完成

import torch
x = torch.range(1,4)
print(x)
===> tensor([ 1.,  2.,  3.,  4.]) 
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2

【讨论】:

  • 请注意,如果您转换为 numpy,则会丢失梯度图。
  • 如果没有匹配会发生什么?你会怎么处理呢?
【解决方案5】:

已经给出的答案很好,但是当我尝试没有匹配时它们无法处理。为此,请参阅:

def index(tensor: Tensor, value, ith_match:int =0) -> Tensor:
    """
    Returns generalized index (i.e. location/coordinate) of the first occurence of value
    in Tensor. For flat tensors (i.e. arrays/lists) it returns the indices of the occurrences
    of the value you are looking for. Otherwise, it returns the "index" as a coordinate.
    If there are multiple occurences then you need to choose which one you want with ith_index.
    e.g. ith_index=0 gives first occurence.

    Reference: https://stackoverflow.com/a/67175757/1601580
    :return:
    """
    # bool tensor of where value occurred
    places_where_value_occurs = (tensor == value)
    # get matches as a "coordinate list" where occurence happened
    matches = (tensor == value).nonzero()  # [number_of_matches, tensor_dimension]
    if matches.size(0) == 0:  # no matches
        return -1
    else:
        # get index/coordinate of the occurence you want (e.g. 1st occurence ith_match=0)
        index = matches[ith_match]
        return index

感谢这个出色的答案:https://stackoverflow.com/a/67175757/1601580

【讨论】:

    【解决方案6】:

    用于查找一维张量/数组中元素的索引 示例

    mat=torch.tensor([1,8,5,3])
    

    找到5的索引

    five=5
    
    numb_of_col=4
    for o in range(numb_of_col):
       if mat[o]==five:
         print(torch.tensor([o]))
    

    要查找 2d/3d 张量的元素索引,将其转换为 1d #ie example.view(元素个数)

    例子

    mat=torch.tensor([[1,2],[4,3])
    #to find index of 2
    
    five = 2
    mat=mat.view(4)
    numb_of_col = 4
    for o in range(numb_of_col):
       if mat[o] == five:
         print(torch.tensor([o]))    
    

    【讨论】:

    • 这是一个老问题,OP 很可能不会寻找快速答案。如果您想发布答案,请花点时间解释您的代码的作用以及它如何添加到已经存在的答案中。现在的代码块对社区几乎没有价值
    【解决方案7】:

    对于浮点张量,我用这个来获取张量中元素的索引。

    print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())
    

    这里我想获取max_value在float tensor中的索引,你也可以把你的value这样放到tensor中任意元素的索引。

    print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())
    

    【讨论】:

      【解决方案8】:
          import torch
          x_data = variable(torch.Tensor([[1.0], [2.0], [3.0]]))
          print(x_data.data[0])
          >>tensor([1.])
      

      【讨论】:

        猜你喜欢
        • 2020-01-15
        • 2021-01-25
        • 2019-01-13
        • 2020-03-17
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多