【问题标题】:Extracting tensor data with index in pytorch在pytorch中使用索引提取张量数据
【发布时间】:2021-11-12 05:56:13
【问题描述】:

我想让张量以某种方式索引。

假设我的数据,张量 X 形 (1, 3, 16, 9)

tensor([[[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
      [ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
      [ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
      [ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
      [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
      [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
      [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
      [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
      [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
      [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
      [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
      [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
      [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.],
      [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.],
      [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.],
      [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]],

     [[ 0.,  0.,  0.,  0., 17., 18.,  0., 21., 22.],
      [ 0.,  0.,  0., 17., 18., 19., 21., 22., 23.],
      [ 0.,  0.,  0., 18., 19., 20., 22., 23., 24.],
      [ 0.,  0.,  0., 19., 20.,  0., 23., 24.,  0.],
      [ 0., 17., 18.,  0., 21., 22.,  0., 25., 26.],
      [17., 18., 19., 21., 22., 23., 25., 26., 27.],
      [18., 19., 20., 22., 23., 24., 26., 27., 28.],
      [19., 20.,  0., 23., 24.,  0., 27., 28.,  0.],
      [ 0., 21., 22.,  0., 25., 26.,  0., 29., 30.],
      [21., 22., 23., 25., 26., 27., 29., 30., 31.],
      [22., 23., 24., 26., 27., 28., 30., 31., 32.],
      [23., 24.,  0., 27., 28.,  0., 31., 32.,  0.],
      [ 0., 25., 26.,  0., 29., 30.,  0.,  0.,  0.],
      [25., 26., 27., 29., 30., 31.,  0.,  0.,  0.],
      [26., 27., 28., 30., 31., 32.,  0.,  0.,  0.],
      [27., 28.,  0., 31., 32.,  0.,  0.,  0.,  0.]],

     [[ 0.,  0.,  0.,  0., 33., 34.,  0., 37., 38.],
      [ 0.,  0.,  0., 33., 34., 35., 37., 38., 39.],
      [ 0.,  0.,  0., 34., 35., 36., 38., 39., 40.],
      [ 0.,  0.,  0., 35., 36.,  0., 39., 40.,  0.],
      [ 0., 33., 34.,  0., 37., 38.,  0., 41., 42.],
      [33., 34., 35., 37., 38., 39., 41., 42., 43.],
      [34., 35., 36., 38., 39., 40., 42., 43., 44.],
      [35., 36.,  0., 39., 40.,  0., 43., 44.,  0.],
      [ 0., 37., 38.,  0., 41., 42.,  0., 45., 46.],
      [37., 38., 39., 41., 42., 43., 45., 46., 47.],
      [38., 39., 40., 42., 43., 44., 46., 47., 48.],
      [39., 40.,  0., 43., 44.,  0., 47., 48.,  0.],
      [ 0., 41., 42.,  0., 45., 46.,  0.,  0.,  0.],
      [41., 42., 43., 45., 46., 47.,  0.,  0.,  0.],
      [42., 43., 44., 46., 47., 48.,  0.,  0.,  0.],
      [43., 44.,  0., 47., 48.,  0.,  0.,  0.,  0.]]]]

我希望将(row_index % n) == i(比如n = 4i = 0 to 3)保存在另一个张量Y中的那些行。

例如对于数据X[0][0]

[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
 [ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
 [ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
 [ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
 [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
 [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
 [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
 [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
 [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
 [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
 [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
 [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
 [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.],
 [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.],
 [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.],      
 [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]

我想要一个包含以下数据的张量,它基本上是row_index % 4 == 0(这里是i = 0)所在的行的集合:

[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
 [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
 [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
 [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]]

同样,i = 1row_index % 4 == i 看起来像:

[[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
 [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
 [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
 [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]]

i = 2row_index % 4 == i

[[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
 [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
 [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
 [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]]

i = 3row_index % 4 == i

[[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
 [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
 [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
 [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]

我尝试过对其进行硬编码,但当数据变得更大并且大小变得动态时,它似乎并不实用,我认为会有更好的方法来实现它。

temp0 = data[0][0][0][:] 
temp1 = data[0][0][4][:]
temp2 = data[0][0][8][:]
temp3 = data[0][0][12][:]
temp = torch.stack([temp0,temp1,temp2,temp3],dim = 0)

另外,如果结果可以返回一个张量,那就太好了:

tensor Y = ([[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
              [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
              [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
              [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]], 

             [[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
              [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
              [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
              [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]], 
   
             [[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
              [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
              [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
              [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]], 

             [[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
              [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
              [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
              [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]])

【问题讨论】:

    标签: pytorch data-manipulation tensor


    【解决方案1】:

    首先,要获得每个爱国,你可以试试这个:

    import torch
    
    data = torch.tensor([[[[0., 0., 0., 0., 1., 2., 0., 5., 6.],
                           [0., 0., 0., 1., 2., 3., 5., 6., 7.],
                           [0., 0., 0., 2., 3., 4., 6., 7., 8.],
                           [0., 0., 0., 3., 4., 0., 7., 8., 0.],
                           [0., 1., 2., 0., 5., 6., 0., 9., 10.],
                           [1., 2., 3., 5., 6., 7., 9., 10., 11.],
                           [2., 3., 4., 6., 7., 8., 10., 11., 12.],
                           [3., 4., 0., 7., 8., 0., 11., 12., 0.],
                           [0., 5., 6., 0., 9., 10., 0., 13., 14.],
                           [5., 6., 7., 9., 10., 11., 13., 14., 15.],
                           [6., 7., 8., 10., 11., 12., 14., 15., 16.],
                           [7., 8., 0., 11., 12., 0., 15., 16., 0.],
                           [0., 9., 10., 0., 13., 14., 0., 0., 0.],
                           [9., 10., 11., 13., 14., 15., 0., 0., 0.],
                           [10., 11., 12., 14., 15., 16., 0., 0., 0.],
                           [11., 12., 0., 15., 16., 0., 0., 0., 0.]],
    
                          [[0., 0., 0., 0., 17., 18., 0., 21., 22.],
                           [0., 0., 0., 17., 18., 19., 21., 22., 23.],
                           [0., 0., 0., 18., 19., 20., 22., 23., 24.],
                           [0., 0., 0., 19., 20., 0., 23., 24., 0.],
                           [0., 17., 18., 0., 21., 22., 0., 25., 26.],
                           [17., 18., 19., 21., 22., 23., 25., 26., 27.],
                           [18., 19., 20., 22., 23., 24., 26., 27., 28.],
                           [19., 20., 0., 23., 24., 0., 27., 28., 0.],
                           [0., 21., 22., 0., 25., 26., 0., 29., 30.],
                           [21., 22., 23., 25., 26., 27., 29., 30., 31.],
                           [22., 23., 24., 26., 27., 28., 30., 31., 32.],
                           [23., 24., 0., 27., 28., 0., 31., 32., 0.],
                           [0., 25., 26., 0., 29., 30., 0., 0., 0.],
                           [25., 26., 27., 29., 30., 31., 0., 0., 0.],
                           [26., 27., 28., 30., 31., 32., 0., 0., 0.],
                           [27., 28., 0., 31., 32., 0., 0., 0., 0.]],
    
                          [[0., 0., 0., 0., 33., 34., 0., 37., 38.],
                           [0., 0., 0., 33., 34., 35., 37., 38., 39.],
                           [0., 0., 0., 34., 35., 36., 38., 39., 40.],
                           [0., 0., 0., 35., 36., 0., 39., 40., 0.],
                           [0., 33., 34., 0., 37., 38., 0., 41., 42.],
                           [33., 34., 35., 37., 38., 39., 41., 42., 43.],
                           [34., 35., 36., 38., 39., 40., 42., 43., 44.],
                           [35., 36., 0., 39., 40., 0., 43., 44., 0.],
                           [0., 37., 38., 0., 41., 42., 0., 45., 46.],
                           [37., 38., 39., 41., 42., 43., 45., 46., 47.],
                           [38., 39., 40., 42., 43., 44., 46., 47., 48.],
                           [39., 40., 0., 43., 44., 0., 47., 48., 0.],
                           [0., 41., 42., 0., 45., 46., 0., 0., 0.],
                           [41., 42., 43., 45., 46., 47., 0., 0., 0.],
                           [42., 43., 44., 46., 47., 48., 0., 0., 0.],
                           [43., 44., 0., 47., 48., 0., 0., 0., 0.]]]])
    
    print(data.shape)
    
    n, i = 4, 0
    indices = [index for index in range(data.shape[2]) if index % n == i]
    print(data[0, 0, indices])
    

    对于这些张量的组合,您可以尝试使用:

    n = 4
    result = []
    for i in range(n):
        indices = [index for index in range(data.shape[2]) if index % n == i]
        result.append(data[0, 0, indices])
    
    final = torch.stack(result, dim=0)
    

    【讨论】:

    • 太好了,让我知道这是否适合您。如果是,请将此问题标记为已回答:)
    • 谢谢,成功了! :D
    【解决方案2】:

    您可以通过首先构造一个包含所选行的张量,然后使用torch.gather 组装最终张量来实现此目的。

    假设我们有两个 lists IN 分别包含 in 的值:

    I = [0, 1, 2, 3]
    N = [4, 4, 4, 4]
    

    首先我们构造索引张量:

    >>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
    tensor([[[ 0],
             [ 4],
             [ 8],
             [12]],
    
            [[ 1],
             [ 5],
             [ 9],
             [13]],
    
            [[ 2],
             [ 6],
             [10],
             [14]],
    
            [[ 3],
             [ 7],
             [11],
             [15]]])
    

    然后需要一些扩展和重塑:

    >>> index_ = index[None].flatten(1,2).expand(X.size(0), -1, X.size(-1))
    tensor([[[ 0,  0,  0,  0,  0,  0,  0,  0,  0],
             [ 4,  4,  4,  4,  4,  4,  4,  4,  4],
             [ 8,  8,  8,  8,  8,  8,  8,  8,  8],
             [12, 12, 12, 12, 12, 12, 12, 12, 12],
             [ 1,  1,  1,  1,  1,  1,  1,  1,  1],
             [ 5,  5,  5,  5,  5,  5,  5,  5,  5],
             [ 9,  9,  9,  9,  9,  9,  9,  9,  9],
             [13, 13, 13, 13, 13, 13, 13, 13, 13],
             [ 2,  2,  2,  2,  2,  2,  2,  2,  2],
             [ 6,  6,  6,  6,  6,  6,  6,  6,  6],
             [10, 10, 10, 10, 10, 10, 10, 10, 10],
             [14, 14, 14, 14, 14, 14, 14, 14, 14],
             [ 3,  3,  3,  3,  3,  3,  3,  3,  3],
             [ 7,  7,  7,  7,  7,  7,  7,  7,  7],
             [11, 11, 11, 11, 11, 11, 11, 11, 11],
             [15, 15, 15, 15, 15, 15, 15, 15, 15]]])
    

    根据经验,我们希望index_ 具有与X 相同的维度数。

    现在我们可以申请 torch.gather 并重塑最终形式:

    >>> X.gather(1, index_).reshape(len(X), *index.shape[:2], -1)
    tensor([[[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
              [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
              [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
              [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]],
    
             [[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
              [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
              [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
              [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]],
    
             [[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
              [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
              [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
              [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]],
    
             [[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
              [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
              [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
              [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]]])
    

    这个方法可以扩展到批量张量:

    >>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
    >>> index_  = index[None,None].flatten(2,3).expand(X.size(0), X.size(1), -1, X.size(-1))
    
    >>> X.gather(2, index_).reshape(*X.shape[:2], *index.shape[:2], -1)
    

    【讨论】:

    • 当我运行 "X.gather(1, index_).reshape(len(X), *index.shape[:2], -1)" 时出现此错误:RuntimeError: Index tensor必须与输入张量具有相同的维数。还有什么好文章或课程可供查找,因为我似乎无法可视化数据操作功能并且仍在努力理解您的代码谢谢
    • 输入张量X的形状是什么?如果有关广播/重塑方法的详细信息,我将编辑我的问题。
    • 感谢您回来。批处理扩展有效,似乎扩展 index_ 为第一部分做了诀窍。 "index_ = index_.view(index_.size(0),1,index_.size(1),index_.size(2)) index_ = index_.expand(-1,X.size(1),-1,- 1)"。再次感谢您
    猜你喜欢
    • 2019-02-05
    • 2021-12-16
    • 2020-03-28
    • 2019-11-26
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多