虽然solution of Berriel 解决了这个特定问题,但我认为添加一些解释可能会帮助大家了解这里使用的技巧,以便它可以适应(m)任何其他维度。
让我们从检查输入张量x的形状开始:
In [58]: x.shape
Out[58]: torch.Size([3, 2, 2])
所以,我们有一个形状为 (3, 2, 2) 的 3D 张量。现在,根据 OP 的问题,我们需要计算张量中沿 1st 和 2nd 维度的值的maximum。在撰写本文时,torch.max() 的 dim 参数仅支持 int。所以,我们不能使用元组。因此,我们将使用以下技巧,我将其称为,
Flatten & Max Trick:因为我们想要在 1st 和 2nd 维度上计算 max,我们将进行展平这两个维度都归为一个维度,而第 0th 维度保持不变。这正是正在发生的事情:
In [61]: x.flatten().reshape(x.shape[0], -1).shape
Out[61]: torch.Size([3, 4]) # 2*2 = 4
所以,现在我们将 3D 张量缩小为 2D 张量(即矩阵)。
In [62]: x.flatten().reshape(x.shape[0], -1)
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
[-0.1821, -0.1747, -0.1526, -0.1453],
[-0.0642, -0.0568, -0.0347, -0.0274]])
现在,我们可以简单地将max 应用于第一个st 维度(即在这种情况下,第一个维度也是最后一个维度),因为展平的维度位于该维度中。
In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1) # or: `dim = -1`
Out[65]:
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))
我们在结果张量中得到 3 个值,因为矩阵中有 3 行。
现在,另一方面,如果您想在 0th 和 1st 维度上计算 max,您可以:
In [80]: x.flatten().reshape(-1, x.shape[-1]).shape
Out[80]: torch.Size([6, 2]) # 3*2 = 6
In [79]: x.flatten().reshape(-1, x.shape[-1])
Out[79]:
tensor([[-0.3000, -0.2926],
[-0.2705, -0.2632],
[-0.1821, -0.1747],
[-0.1526, -0.1453],
[-0.0642, -0.0568],
[-0.0347, -0.0274]])
现在,我们可以简单地将max 应用于第 0th 维度,因为这是我们展平的结果。 ((同样,从我们原来的 (3, 2, 2) 形状来看,在前 2 个维度取 max 之后,我们应该得到两个值作为结果。)
In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0)
Out[82]:
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))
类似地,您可以将此方法应用于多维和其他缩减函数,例如min。
注意:我遵循基于 0 的维度 (0, 1, 2, 3, ...) 的术语只是为了与 PyTorch 的使用和代码保持一致。