【发布时间】:2021-06-21 20:43:49
【问题描述】:
我正在尝试在图神经网络中实现消息传递。在每个图中,都有边和节点,节点到边的更新实现如下: 其中方括号表示拼接操作,下标为索引,上标为时间索引。
所以我试图连接 3 个维度矩阵:AxN、AxBxM 和 BxN。并且得到的串联是维度:AxBx(2N+M)。因此,结果矩阵的每个 (i,j) 都是第一个矩阵的第 i 行、第三个矩阵的第 j 行和第二个矩阵的第 (i,j) 个元素的串联。我设法在一个双 for 循环中实现了这一点,如下所示:
edge_in = torch.zeros(a, b, m + 2 * n)
edge_in = edge_in.cuda()
for i in range(a):
for j in range(b):
edge_in[i,j] = torch.cat((nodes_a_embeds[i], edge_embeds[i,j], nodes_b_embeds[j]))
但是,这非常慢。这是否以任何方式可矢量化?我试图想出一个解决方案,然后我在网上寻找一个解决方案,但无法将它矢量化。谢谢。
编辑:根据要求的数字示例:
第一个矩阵:5x3 第二个矩阵:5x4x2 第三个矩阵:4x3
那么输出应该是 5x4x8。让我们称我们的输出矩阵为 R。
然后 R(1,2) = 连接(First(1),Second(1,2),Third(2))。
【问题讨论】:
-
您能否添加三个数组的示例输入和预期输出示例?
-
加了个例子,现在是不是更清楚了?
标签: python matrix concatenation vectorization