在最简单的情况下,torch.nn.functional.embedding_bag 在概念上是一个两步过程。第一步是创建一个嵌入,第二步是减少(总和/平均/最大值,根据“模式”参数)跨维度 0 的嵌入输出。因此,您可以通过调用 @ 得到与 embedding_bag 相同的结果987654322@,然后是torch.sum/mean/max。在以下示例中,embedding_bag_res 和 embedding_mean_res 相等。
>>> weight = torch.randn(3, 4)
>>> weight
tensor([[ 0.3987, 1.6173, 0.4912, 1.5001],
[ 0.2418, 1.5810, -1.3191, 0.0081],
[ 0.0931, 0.4102, 0.3003, 0.2288]])
>>> indices = torch.tensor([2, 1])
>>> embedding_res = torch.nn.functional.embedding(indices, weight)
>>> embedding_res
tensor([[ 0.0931, 0.4102, 0.3003, 0.2288],
[ 0.2418, 1.5810, -1.3191, 0.0081]])
>>> embedding_mean_res = embedding_res.mean(dim=0, keepdim=True)
>>> embedding_mean_res
tensor([[ 0.1674, 0.9956, -0.5094, 0.1185]])
>>> embedding_bag_res = torch.nn.functional.embedding_bag(indices, weight, torch.tensor([0]), mode='mean')
>>> embedding_bag_res
tensor([[ 0.1674, 0.9956, -0.5094, 0.1185]])
但是,概念上的两步流程并未反映其实际实施方式。由于embedding_bag 不需要返回中间结果,它实际上并没有为嵌入生成张量对象。它直接计算减少量,根据input 参数中的索引从weight 参数中提取适当的数据。避免创建嵌入张量可以提高性能。
所以你的问题的答案(如果我理解正确的话)
这只是意味着在调用 embedding_bag 时不会产生额外的 Tensor,但从系统的角度来看,所有中间行向量已经被提取到处理器中以用于计算最终的 Tensor?
是的。