【问题标题】:PyTorch tensors topk for every tensor across a dimensionPyTorch 张量 topk 为跨维度的每个张量
【发布时间】:2021-04-02 00:18:50
【问题描述】:

我有以下张量

inp = tensor([[[ 0.0000e+00,  5.7100e+02, -6.9846e+00],
     [ 0.0000e+00,  4.4070e+03, -7.1008e+00],
     [ 0.0000e+00,  3.0300e+02, -7.2226e+00],
     [ 0.0000e+00,  6.8000e+01, -7.2777e+00],
     [ 1.0000e+00,  5.7100e+02, -6.9846e+00],
     [ 1.0000e+00,  4.4070e+03, -7.1008e+00],
     [ 1.0000e+00,  3.0300e+02, -7.2226e+00],
     [ 1.0000e+00,  6.8000e+01, -7.2777e+00]],

    [[ 0.0000e+00,  2.1610e+03, -7.0754e+00],
     [ 0.0000e+00,  6.8000e+01, -7.2259e+00],
     [ 0.0000e+00,  1.0620e+03, -7.2920e+00],
     [ 0.0000e+00,  2.9330e+03, -7.3009e+00],
     [ 1.0000e+00,  2.1610e+03, -7.0754e+00],
     [ 1.0000e+00,  6.8000e+01, -7.2259e+00],
     [ 1.0000e+00,  1.0620e+03, -7.2920e+00],
     [ 1.0000e+00,  2.9330e+03, -7.3009e+00]],

    [[ 0.0000e+00,  4.4070e+03, -7.1947e+00],
     [ 0.0000e+00,  3.5600e+02, -7.2958e+00],
     [ 0.0000e+00,  3.0300e+02, -7.3232e+00],
     [ 0.0000e+00,  1.2910e+03, -7.3615e+00],
     [ 1.0000e+00,  4.4070e+03, -7.1947e+00],
     [ 1.0000e+00,  3.5600e+02, -7.2958e+00],
     [ 1.0000e+00,  3.0300e+02, -7.3232e+00],
     [ 1.0000e+00,  1.2910e+03, -7.3615e+00]]])

形状

torch.Size([3, 8, 3])

我想在 dim1 中找到 topk(k=4) 元素,其中要排序的值是 dim2(负值)。得到的张量形状应该是:

torch.Size([3, 4, 3])

我知道如何对单个张量进行 topk,但是如何一次对多个批次进行此操作?

【问题讨论】:

    标签: python pytorch tensor


    【解决方案1】:

    我是这样做的:

    val, ind = inp[:, :, 2].squeeze().topk(k=4, dim=1, sorted=True)
    new_ind = ind.unsqueeze(-1).repeat(1,1,3)
    result = inp.gather(1, new_ind)
    

    我不知道这是否是最好的方法,但它确实有效。

    【讨论】:

    • 使用gather() 方法的好例子。不错。
    【解决方案2】:

    一种方法是将fancy indexingbroadcasting 组合如下:

    我以形状为(3, 4, 3)k 的随机张量x 为2 为例。

    >>> import torch
    >>> x = torch.rand(3, 4, 3)
    >>> x
    tensor([[[0.0256, 0.7366, 0.2528],
             [0.5596, 0.9450, 0.5795],
             [0.8265, 0.5469, 0.8304],
             [0.4223, 0.5206, 0.2898]],
    
            [[0.2159, 0.0369, 0.6869],
             [0.4556, 0.5804, 0.3169],
             [0.8194, 0.5240, 0.0055],
             [0.8357, 0.4162, 0.3740]],
    
            [[0.3849, 0.0223, 0.9951],
             [0.2872, 0.5952, 0.6570],
             [0.1433, 0.8450, 0.6557],
             [0.0270, 0.9176, 0.3904]]])
    

    现在按照所需维度(最后一个)对张量进行排序并获取索引:

    >>> _, idx = torch.sort(x[:, :, -1])
    >>> k = 2
    >>> idx = idx[:, :k]
    # idx is = 
    tensor([[0, 3],
            [2, 1],
            [3, 2]])
    

    现在生成三对索引(i, j, k) 对原始张量进行切片,如下所示:

    >>> i = torch.arange(x.shape[0]).reshape(x.shape[0], 1, 1)
    >>> j = idx.reshape(x.shape[0], -1, 1)
    >>> k = torch.arange(x.shape[2]).reshape(1, 1, x.shape[2])
    

    请注意,一旦您通过(i, j, k) 索引任何内容,它们将转到expand 并采用(x.shape[0], k, x.shape[2]) 形状,这是此处所需的输出形状。 现在只需按 i、j 和 k 索引 x

    >>> x[i, j, k]
    tensor([[[0.0256, 0.7366, 0.2528],
             [0.4223, 0.5206, 0.2898]],
    
            [[0.8194, 0.5240, 0.0055],
             [0.4556, 0.5804, 0.3169]],
    
            [[0.0270, 0.9176, 0.3904],
             [0.1433, 0.8450, 0.6557]]])
    

    基本上,我遵循的一般方法是通过索引数组创建张量的相应访问模式,然后使用这些数组作为索引直接对张量进行切片。

    实际上我这样做是为了升序排序,所以在这里我得到了 top-k 最少的元素。一个简单的解决方法是使用torch.sort(x[:, :, -1], descending = True)

    【讨论】:

      猜你喜欢
      • 2021-10-23
      • 2017-09-05
      • 1970-01-01
      • 2020-09-26
      • 2020-08-31
      • 2022-07-26
      • 2021-10-11
      • 2020-08-26
      • 1970-01-01
      相关资源
      最近更新 更多