简单地说,torch.Tensor.view() 受numpy.ndarray.reshape() 或numpy.reshape() 启发,创建张量的新视图,只要新形状与原始形状兼容张量。
让我们通过一个具体的例子来详细理解这一点。
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
使用形状为(18,) 的张量t,可以仅为以下形状创建新的视图:
(1, 18) 或等效的 (1, -1) 或 (-1, 18)
(2, 9) 或等同于 (2, -1) 或 (-1, 9)
(3, 6) 或等同于 (3, -1) 或 (-1, 6)
(6, 3) 或等效的 (6, -1) 或 (-1, 3)
(9, 2) 或等效 (9, -1) 或 (-1, 2)
(18, 1) 或等效 @987654352 @ 或 (-1, 1)
正如我们已经从上面的形状元组中观察到的那样,形状元组的元素相乘(例如2*9、3*6 等)必须始终等于总数原始张量中的元素(在我们的示例中为18)。
要注意的另一件事是,我们在每个形状元组的一个位置使用了-1。通过使用-1,我们懒得自己进行计算,而是将任务委托给 PyTorch 在创建新 视图 时为形状计算该值。需要注意的重要一点是,我们可以仅在形状元组中使用单个 -1。其余值应由我们明确提供。否则 PyTorch 将通过抛出 RuntimeError 来抱怨:
RuntimeError: 只能推断出一个维度
因此,对于上述所有形状,PyTorch 将始终返回原始张量 t 的新视图。这基本上意味着它只是为请求的每个新视图更改张量的步幅信息。
以下是一些示例,说明张量的步幅如何随着每个新的视图而改变。
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
现在,我们将看到新视图的进步:
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
这就是view() 函数的魔力。只要新 view 的形状与原始形状兼容,它只会更改每个新 views 的(原始)张量的步幅。 p>
从 strides 元组中可能观察到的另一件有趣的事情是,第 0th 位置的元素的值等于第 1st位置的元素的值> 形状元组的位置。
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
这是因为:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
步幅(6, 1) 表示要沿着第 0th 维度从一个元素移动到下一个元素,我们必须跳转 或走 6 步。 (即从0 到6,需要6 个步骤。)但是要从第一个st 维度中的一个元素到下一个元素,我们只需要一个步骤(例如从2 到3)。
因此,步长信息是如何从内存中访问元素以执行计算的核心。
这个函数将返回一个 view 并且与使用 torch.Tensor.view() 完全相同,只要新形状与原始张量的形状兼容。否则,它将返回一个副本。
但是,torch.reshape() 的注释警告说:
连续输入和具有兼容步幅的输入可以在不复制的情况下进行重新整形,但不应依赖于复制与查看行为。