【问题标题】:How to vectorize attention operation and avoid for-loop如何向量化注意力操作并避免for循环
【发布时间】:2021-11-03 14:52:56
【问题描述】:

我是 Attention 的新手,我在我的模型中的 forward() 函数的以下代码中使用 python for 循环实现了一种这样的机制。

基本上,我有一个项目嵌入层,通过它我可以嵌入一个项目和一系列其他项目,这些项目我通过注意力权重加权。为了获得注意力权重,我使用了一个子网络 (nn.Sequencial(...)),它将一对两个项目嵌入作为输入,并在回归中输出一个分数。然后将所有分数进行 softmax 化并用作注意力权重。

def forward(self, input_features, ...):
    ...
    """ B = batch size, I = number of items for attention, E = embedding size """
    ...
    
    # get embeddings from input features for current batch
    embeddings = self.embedding_layer(input_features)         # (B, E)
    other_embeddings = self.embedding_layer(other_features)   # (I, E)

    # attention between pairs of embeddings
    attention_scores = torch.zeros((B, I))             # (B, I)
    for i in range(I):
        # repeat batch-size times for i-th embedding
        repeated_other_embedding = other_embeddings[i].view(1, -1).repeat(B, 1)   # (B, E)

        # concat pairs of embeddings to form input to attention network   
        item_emb_pairs = torch.cat((embeddings.detach(), repeated_other_embedding.detach()), dim=1)

        # pass batch through attention network
        attention_scores[:, [i]] = self.AttentionNet(item_emb_pairs)

    # pass through softmax
    attention_scores = F.softmax(attention_scores, dim=1)   # (B, I)

    ...

如何避免 python for 循环,我怀疑是什么让训练减慢了这么多?我可以以某种方式在 self.AttentionNet() 中传递维度矩阵 (I, B, 2*E) 吗?

【问题讨论】:

    标签: python pytorch vectorization attention-model


    【解决方案1】:

    你可以使用下面的sn-p。

    embeddings = self.embedding_layer(input_features)         # (B, E)
    other_embeddings = self.embedding_layer(other_features)   # (I, E)
    
    embs = embeddings.unsqueeze(1).repeat(1, I, 1)              # (B, I, E)
    other_embs = other_embeddings.unsqueeze(0).repeat(B, 1, 1)  # (B, I, E)
    
    concatenated_embeddings = torch.cat((embs, other_embs), dim=2)  # (B, I, 2*E)
    
    attention_scores = F.softmax(self.AttentionNet(concatenated_embeddings))    #(B, I)
    

    您可能需要在 self.AttentionNet 中进行一些更改,因为在这种情况下,您将 Batch 大小为 B 的输入张量提供给注意力网络。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2015-05-22
      • 1970-01-01
      • 2020-12-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多