【问题标题】:Product of PyTorch tensors along arbitrary axes à la NumPy's `tensordot`PyTorch 张量沿任意轴的乘积 à la NumPy 的 `tensordot`
【发布时间】:2026-01-11 07:00:01
【问题描述】:

NumPy 提供了非常有用的tensordot 函数。它允许您计算两个 ndarrays 沿任意轴(其大小匹配)的乘积。我很难在 PyTorch 中找到类似的东西。 mm 仅适用于二维数组,而matmul 有一些不受欢迎的广播属性。

我错过了什么吗?我真的打算使用mm 重塑阵列以模仿我想要的产品吗?

【问题讨论】:

  • @M.Deckers:怎么可能?它甚至不需要参数来指定要带产品的轴。
  • 目前不可用,但目前正在讨论here
  • @McLawrence:谢谢,这很清楚!

标签: matrix-multiplication pytorch dot-product


【解决方案1】:

最初的答案是完全正确的,但作为更新,Pytorch now supports tensordot 原生。与 numpy 相同的调用签名,但将 axes 更改为 dims

import torch
import numpy as np

a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.tensordot(a, b, dims=([1,0],[0,1]))
print(c)
# tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)

【讨论】:

    【解决方案2】:

    正如@McLawrence 所述,目前正在讨论此功能 (issue thread)。

    同时,您可以考虑torch.einsum(),例如:

    import torch
    import numpy as np
    
    a = np.arange(36.).reshape(3,4,3)
    b = np.arange(24.).reshape(4,3,2)
    c = np.tensordot(a, b, axes=([1,0],[0,1]))
    print(c)
    # [[ 2640.  2838.] [ 2772.  2982.] [ 2904.  3126.]]
    
    a = torch.from_numpy(a)
    b = torch.from_numpy(b)
    c = torch.einsum("ijk,jil->kl", (a, b))
    print(c)
    # tensor([[ 2640.,  2838.], [ 2772.,  2982.], [ 2904.,  3126.]], dtype=torch.float64)
    

    【讨论】:

      最近更新 更多