【问题标题】:How can I calculate all cross-terms in pytorch?如何计算 pytorch 中的所有交叉项?
【发布时间】:2021-12-14 14:02:41
【问题描述】:

我想计算矩阵中每个向量的所有交叉项。 例如,考虑以下矩阵:

X = tensor([[1, 2, 3],
            [4, 5, 6]]),

我想获得这个矩阵中每个向量的所有交叉项:

Y = [[1*1, 1*2, 1*3, 2*2, 2*3, 3*3],
     [4*4, 4*5, 4*6, 5*5, 5*6, 6*6]].
  = [[1, 2, 3, 4, 6, 9],
     [16, 20, 24, 25, 30, 36]].

即这是向量元素的所有组合值 我相信这可以使用 torch.combinations 来计算; 但是,torch.combinations 不提供批处理实现 我无法在 pytorch 中产生上述结果。

如何计算 pytorch 中的所有交叉项?

【问题讨论】:

    标签: pytorch


    【解决方案1】:

    您可以堆叠组合的乘积并替换该矩阵中的每一行

    >>> torch.stack(tuple(torch.prod(torch.combinations(data[i],with_replacement=True),1) for i in range(data.shape[0])),0)
    >>> tensor([[ 1,  2,  3,  4,  6,  9],
            [16, 20, 24, 25, 30, 36]])
    

    【讨论】:

    • 非常感谢您的回答,您的代码可以工作,但使用 for 循环处理张量数据。我想将张量数据作为批处理处理,但 torch.combination 不是基于批处理的实现。这是一个困难的实现吗? github.com/pytorch/pytorch/issues/40375
    猜你喜欢
    • 2020-05-26
    • 2020-12-23
    • 1970-01-01
    • 2022-01-09
    • 2020-06-04
    • 1970-01-01
    • 2019-03-05
    • 1970-01-01
    相关资源
    最近更新 更多