【问题标题】:How to shift columns (or rows) in a tensor with different offsets in PyTorch?如何在 PyTorch 中移动具有不同偏移量的张量中的列(或行)?
【发布时间】:2021-06-10 06:46:13
【问题描述】:

在 PyTorch 中,内置的 torch.roll 函数只能移动具有相同偏移量的列(或行)。但我想用不同的偏移量移动列。假设输入张量是

[[1,2,3],
 [4,5,6],
 [7,8,9]]

假设,我想为第 i 列移动偏移量 i。因此,预期的输出是

[[1,8,6],
 [4,2,9],
 [7,5,3]]

这样做的一个选项是使用torch.roll 分别移动每一列并连接它们中的每一个。但是出于有效性和代码紧凑性的考虑,我不想介绍循环结构。有没有更好的办法?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    我对@9​​87654322@的性能持怀疑态度,所以我用numpy搜索了类似的问题,发现this的帖子。

    从 NumPy 到 Pytorch 的类似解决方案

    我从@Andy L 那里得到了解决方案并将其翻译成pytorch。但是,请持保留态度,因为我不知道大步是如何工作的:

    from numpy.lib.stride_tricks import as_strided
    # NumPy solution:
    def custom_roll(arr, r_tup):
        m = np.asarray(r_tup)
        arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].copy() #need `copy`
        #print(arr_roll)
        strd_0, strd_1 = arr_roll.strides
        #print(strd_0, strd_1)
        n = arr.shape[1]
        result = as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1))
    
        return result[np.arange(arr.shape[0]), (n-m)%n]
    
    # Translated to PyTorch
    def pcustom_roll(arr, r_tup):
        m = torch.tensor(r_tup)
        arr_roll = arr[:, [*range(arr.shape[1]),*range(arr.shape[1]-1)]].clone() #need `copy`
        #print(arr_roll)
        strd_0, strd_1 = arr_roll.stride()
        #print(strd_0, strd_1)
        n = arr.shape[1]
        result = torch.as_strided(arr_roll, (*arr.shape, n), (strd_0 ,strd_1, strd_1))
    
        return result[torch.arange(arr.shape[0]), (n-m)%n]
    

    这也是@Daniel M 的即插即用解决方案。

    def roll_by_gather(mat,dim, shifts: torch.LongTensor):
        # assumes 2D array
        n_rows, n_cols = mat.shape
        
        if dim==0:
            #print(mat)
            arange1 = torch.arange(n_rows).view((n_rows, 1)).repeat((1, n_cols))
            #print(arange1)
            arange2 = (arange1 - shifts) % n_rows
            #print(arange2)
            return torch.gather(mat, 0, arange2)
        elif dim==1:
            arange1 = torch.arange(n_cols).view(( 1,n_cols)).repeat((n_rows,1))
            #print(arange1)
            arange2 = (arange1 - shifts) % n_cols
            #print(arange2)
            return torch.gather(mat, 1, arange2)
        
    

    基准测试

    首先,我在 CPU 上运行这些方法。 令人惊讶的是,上面的gather 解决方案是最快的:

    n_cols = 10000
    n_rows = 100
    shifts = torch.randint(-100,100,size=[n_rows,1])
    data = torch.arange(n_rows*n_cols).reshape(n_rows,n_cols)
    npdata = np.arange(n_rows*n_cols).reshape(n_rows,n_cols)
    npshifts = shifts.numpy()
    %timeit roll_by_gather(data,1,shifts)
    %timeit pcustom_roll(data,shifts)
    %timeit custom_roll(npdata,npshifts)
    >> 2.41 ms ± 68.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    >> 90.4 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    >> 247 ms ± 6.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    在 GPU 上运行代码显示类似的结果:

    %timeit roll_by_gather(data,shifts)
    %timeit pcustom_roll(data,shifts)
    131 µs ± 6.79 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    3.29 ms ± 46.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    

    注意roll_by_gather方法中需要torch.arange(...,device='cuda:0')

    【讨论】:

      【解决方案2】:

      让我们定义一些名称:

      import torch
      
      mat = torch.Tensor(
      [[1,2,3],
       [4,5,6],
       [7,8,9]])
      
      indices = torch.LongTensor([0, 1, 2]) # Could also use arange in this specific scenario
      

      首先,你可以制作一个类似的张量

      [[0, 0, 0],
       [1, 1, 1],
       [2, 2, 2]]
      

      使用

      arange1 = torch.arange(3).view((3, 1)).repeat((1, 3))
      

      现在,让我们制作目标索引的张量

      [[0, 2, 1],
       [1, 0, 2],
       [2, 1, 0]]
      

      arange2 = (arange1 - indices) % 3
      

      最后,我们得到预期的输出

      torch.gather(mat, 0, arange2)
      

      【讨论】:

        猜你喜欢
        • 2020-04-30
        • 2021-02-06
        • 1970-01-01
        • 2012-05-30
        • 2019-07-09
        • 2019-09-30
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多