【问题标题】:How can I resize a PyTorch tensor with a sliding window?如何使用滑动窗口调整 PyTorch 张量的大小?
【发布时间】:2020-02-10 19:34:05
【问题描述】:

我有一个大小为:torch.Size([118160, 1]) 的张量。我想要做的是将它分成 n 个张量,每个张量有 100 个元素,一次滑动 50 个元素。使用 PyTorch 实现这一目标的最佳方法是什么?

【问题讨论】:

  • 我可能会做[x[i:min(x.size(0),i+100)] for i in range(0,x.size(0),50)],但最后几个元素会短于 100。这种行为可以吗?
  • 将其添加为答案,以便我接受

标签: python pytorch tensor


【解决方案1】:

您可以使用 Pytorch 的展开 API。参考这个https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html

例子:

x = torch.arange(1., 20)
x.unfold(0,4,2)

tensor([[ 1.,  2.,  3.,  4.],  
        [ 3.,  4.,  5.,  6.],  
        [ 5.,  6.,  7.,  8.],  
        [ 7.,  8.,  9., 10.],  
        [ 9., 10., 11., 12.],  
        [11., 12., 13., 14.],  
        [13., 14., 15., 16.],  
        [15., 16., 17., 18.]])

【讨论】:

    【解决方案2】:

    一个可能的解决方案是:

    window_size = 100
    stride = 50
    splits = [x[i:min(x.size(0),i+window_size)] for i in range(0,x.size(0),stride)]
    

    但是,最后几个元素将比window_size 短。如果这是不希望的,您可以这样做:

    splits = [x[i:i+window_size] for i in range(0,x.size(0)-window_size+1,stride)]
    

    编辑:

    更易读的解决方案:

    # if keep_short_tails is set to True, the slices shorter than window_size at the end of the result will be kept 
    def window_split(x, window_size=100, stride=50, keep_short_tails=True):
      length = x.size(0)
      splits = []
    
      if keep_short_tails:
        for slice_start in range(0, length, stride):
          slice_end = min(length, slice_start + window_size)
          splits.append(x[slice_start:slice_end])
      else:
        for slice_start in range(0, length - window_size + 1, stride):
          slice_end = slice_start + window_size
          splits.append(x[slice_start:slice_end])
    
      return splits
    

    【讨论】:

    • 无论如何要让它不那么 Pythonic 并因此更具可读性?
    • 对不起。我倾向于以牺牲可读性为代价来享受写单行文字的乐趣。我添加了一个更可重用(希望更易读)的函数。
    【解决方案3】:

    您的元素数量需要被 100 整除。如果不是这种情况,您可以使用填充进行调整。

    您可以先对原始列表进行拆分。 然后对列表进行拆分,从原始列表中删除前 50 个元素。 如果您想保留原始顺序,则可以从 A 和 B 中交替顺序采样。

    A = yourtensor
    B = yourtensor[50:] + torch.zeros(50,1)
    A_ = A.view(100,-1)
    B_ = B.view(100,-1)
    

    【讨论】:

      猜你喜欢
      • 2020-02-28
      • 2018-11-16
      • 2021-01-22
      • 2021-06-07
      • 1970-01-01
      • 2019-07-03
      • 1970-01-01
      • 1970-01-01
      • 2011-10-06
      相关资源
      最近更新 更多