strides 是在给定维度中从一个元素到下一个元素所需的步数(或跳转)。在计算机内存中,数据线性存储在连续的内存块中。我们查看的只是一个(重新)呈现。
让我们以张量为例来理解这一点:
# a 2D tensor
In [62]: tensor = torch.arange(1, 16).reshape(3, 5)
In [63]: tensor
Out[63]:
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15]])
有了这个张量,步幅是:
# get the strides
In [64]: tensor.stride()
Out[64]: (5, 1)
这个结果元组(5, 1) 说的是:
- 要沿着第 0th 维度/轴(Y 轴)遍历,假设我们要跳转从
1 到6,我们应该采取5 步(或跳跃)
- 要沿着第一个st 维度/轴(X 轴)遍历,假设我们要跳转从
7 到8,我们应该采取1 步(或跳跃)
元组中5 & 1 的顺序(或索引)表示维度/轴。您还可以将您想要跨度的维度作为参数传递:
# get stride for axis 0
In [65]: tensor.stride(0)
Out[65]: 5
# get stride for axis 1
In [66]: tensor.stride(1)
Out[66]: 1
有了这样的理解,我们可能要问为什么在创建张量时需要这个额外参数?答案是出于效率原因。 (我们如何才能最有效地存储/读取/访问(稀疏)张量中的元素?)。
使用稀疏张量(大多数元素只是零的张量),所以我们不想存储这些值。我们只存储非零值及其索引。使用所需的形状,其余的值可以用零填充,从而产生所需的稀疏张量。
如需进一步阅读,以下文章可能会有所帮助:
P.S:我猜torch.layout 文档中有一个错字
Strides 是一个整数列表 ...
tensor.stride() 返回的复合数据类型是元组,而不是列表。