有多种方法可以重塑 PyTorch 张量。您可以将这些方法应用于任何维度的张量。
让我们从二维2 x 3 张量开始:
x = torch.Tensor(2, 3)
print(x.shape)
# torch.Size([2, 3])
为了给这个问题增加一些鲁棒性,让我们重塑 2 x 3 张量,在前面添加一个新维度,在中间添加另一个维度,生成一个 1 x 2 x 1 x 3 张量。
方法一:用None添加维度
在任何你想要的地方使用 NumPy 风格的 insertion of None (aka np.newaxis) to add dimensions。见here。
print(x.shape)
# torch.Size([2, 3])
y = x[None, :, None, :] # Add new dimensions at positions 0 and 2.
print(y.shape)
# torch.Size([1, 2, 1, 3])
方法2:解压
使用torch.Tensor.unsqueeze(i)(又名torch.unsqueeze(tensor, i) 或就地版本unsqueeze_())在第i 个维度添加一个新维度。返回的张量与原始张量共享相同的数据。在这个例子中,我们可以使用unqueeze() 两次来添加两个新维度。
print(x.shape)
# torch.Size([2, 3])
# Use unsqueeze twice.
y = x.unsqueeze(0) # Add new dimension at position 0
print(y.shape)
# torch.Size([1, 2, 3])
y = y.unsqueeze(2) # Add new dimension at position 2
print(y.shape)
# torch.Size([1, 2, 1, 3])
在 PyTorch 的实践中,adding an extra dimension for the batch 可能很重要,所以你可能经常看到unsqueeze(0)。
方法三:查看
使用torch.Tensor.view(*shape) 指定所有维度。返回的张量与原始张量共享相同的数据。
print(x.shape)
# torch.Size([2, 3])
y = x.view(1, 2, 1, 3)
print(y.shape)
# torch.Size([1, 2, 1, 3])
方法四:重塑
使用torch.Tensor.reshape(*shape)(又名torch.reshape(tensor, shapetuple))指定所有维度。如果原始数据是连续的并且具有相同的步幅,则返回的张量将是输入的视图(共享相同的数据),否则将是副本。此函数类似于 NumPy reshape() 函数,因为它允许您定义所有维度并可以返回视图或副本。
print(x.shape)
# torch.Size([2, 3])
y = x.reshape(1, 2, 1, 3)
print(y.shape)
# torch.Size([1, 2, 1, 3])
此外,来自 O'Reilly 2019 年出版的书 Programming PyTorch for Deep Learning,作者写道:
现在您可能想知道view() 和reshape() 之间有什么区别。答案是view() 作为原始张量上的视图运行,因此如果基础数据发生更改,视图也会更改(反之亦然)。但是,如果所需的视图不连续,view() 可能会抛出错误;也就是说,如果从头开始创建所需形状的新张量,它不会共享相同的内存块。如果发生这种情况,您必须先致电tensor.contiguous(),然后才能使用view()。然而,reshape() 在幕后完成了所有这些工作,所以总的来说,我建议使用reshape() 而不是view()。
方法五:调整大小_
使用就地函数torch.Tensor.resize_(*sizes) 修改原始张量。文档指出:
警告。这是一种低级方法。存储被重新解释为 C 连续,忽略当前步幅(除非目标大小等于当前大小,在这种情况下张量保持不变)。在大多数情况下,您将希望使用检查连续性的view() 或在需要时复制数据的reshape()。要使用自定义步幅就地更改大小,请参阅set_()。
print(x.shape)
# torch.Size([2, 3])
x.resize_(1, 2, 1, 3)
print(x.shape)
# torch.Size([1, 2, 1, 3])
我的观察
如果您只想添加一个维度(例如,为批次添加第 0 个维度),请使用unsqueeze(0)。如果您想完全改变维度,请使用reshape()。
另见:
What's the difference between reshape and view in pytorch?
What is the difference between view() and unsqueeze()?
In PyTorch 0.4, is it recommended to use reshape than view when it is possible?