您要查找的是来自PyTorch 和Numpy 的Tensordot 命令
由于您想计算沿 N 的点积,即 x1 的维度 1 和 x2 张量的维度 1,您需要通过提供 ([1], [1]) 沿两个张量的第一轴执行收缩到 Tensordot 中的 dims arg。这意味着 Torch 将分别在指定的 x1 轴 1 和指定的 x2 轴 1 上求和 x1 和 x2 元素的乘积。提供给dims 的参数很混乱,这里有一个有用的线程来帮助理解如何使用Tensordothere
x1 = torch.arange(6.).reshape(2,3)
>>> tensor([[0., 1., 2.],
[3., 4., 5.]])
# x1 is Tensor of shape (2,3)
x2 = torch.arange(9.).reshape(3,3)
>>> tensor([[0., 1., 2.],
[3., 4., 5.],
[6., 7., 8.]])
# x2 is Tensor of shape (3,3)
x = torch.tensordot(x1, x2, dims=([1],[1]))
>>> tensor([[ 5., 14., 23.],
[14., 50., 86.]])
# x is Tensor of shape (2,3)