【问题标题】:How to understand masked multi-head attention in transformer如何理解transformer中的masked multi-head attention
【发布时间】:2020-01-27 08:06:49
【问题描述】:

我目前正在研究transformer的代码,但是我无法理解decoder的masked multi-head。论文说是为了不让你看到生成词,但是如果生成词之后的词还没有生成,怎么能看到呢?

我尝试阅读转换器的代码(链接:https://github.com/Kyubyong/transformer)。代码实现的掩码如下所示。它使用下三角矩阵来掩盖,我不明白为什么。

padding_num = -2 ** 32 + 1
diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)

【问题讨论】:

    标签: tensorflow deep-learning transformer attention-model


    【解决方案1】:

    阅读Transformer paper 后,我也有同样的问题。我在互联网上没有找到该问题的完整详细答案,所以我将尝试解释我对 Masked Multi-Head Attention 的理解。

    简短的回答是——我们需要掩蔽来使训练并行。并且并行化很好,因为它可以让模型训练得更快。

    这是一个解释这个想法的例子。假设我们训练将“我爱你”翻译成德语。编码器以并行模式工作 - 它可以在恒定步数内生成输入序列(“我爱你”)的矢量表示(即步数不取决于输入序列的长度)。

    假设编码器生成数字11, 12, 13 作为输入序列的向量表示。实际上,这些向量会更长,但为简单起见,我们使用较短的向量。同样为简单起见,我们忽略了服务令牌,例如 - 序列的开头、 - 序列的结尾等。

    在训练期间,我们知道翻译应该是“Ich liebe dich”(我们在训练期间总是知道预期的输出)。假设“Ich liebe dich”词的预期向量表示是21, 22, 23

    如果我们以顺序模式训练解码器,它看起来就像是循环神经网络的训练。将执行以下顺序步骤:

    • 顺序操作 #1。输入:11, 12, 13
      • 试图预测21
      • 预测的输出不会完全是21,假设是21.1
    • 顺序操作 #2。输入:11, 12, 13,还有 21.1 作为之前的输出。
      • 试图预测22
      • 预测的输出不会完全是22,假设是22.3
    • 顺序操作 #3。输入11, 12, 13,也输入22.3作为前一个输出。
      • 试图预测23
      • 预测的输出不会完全是23,假设是23.5

    这意味着我们需要进行 3 个顺序操作(一般情况下 - 每个输入一个顺序操作)。此外,我们将在每次下一次迭代中累积错误。此外,我们不使用注意力,因为我们只查看单个先前的输出。

    由于我们实际上知道预期的输出,我们可以调整流程并使其并行。无需等待上一步输出。

    • 并行操作#A。输入:11, 12, 13
      • 试图预测21
    • 并行操作#B。输入:11, 12, 13,还有21
      • 试图预测22
    • 并行操作#C。输入:11, 12, 13,还有21, 22
      • 试图预测23

    该算法可以并行执行,也不会累积错误。该算法使用注意力(即查看所有先前的输入),因此在进行预测时需要考虑更多有关上下文的信息。

    这里是我们需要掩蔽的地方。训练算法知道整个预期输出 (21, 22, 23)。它为每个并行操作隐藏(屏蔽)这个已知输出序列的一部分。

    • 当它执行 #A - 它隐藏(屏蔽)整个输出。
    • 当它执行 #B 时 - 它会隐藏第 2 和第 3 输出。
    • 当它执行 #C - 它隐藏第三个输出。

    屏蔽本身的实现如下(来自original paper):

    我们通过屏蔽在缩放的点积注意力中实现这一点 输出(设置为 -∞)softmax 输入中的所有值 对应非法连接

    注意:在推理(非训练)期间,解码器以顺序(非并行)模式工作,因为它最初不知道输出序列。但它与 RNN 方法不同,因为 Transformer 推理仍然使用自我注意并查看所有先前的输出(但不仅仅是前一个)。

    注意 2:我在某些材料中看到,掩蔽可用于非翻译应用的不同方式。例如,对于语言建模,掩码可用于从输入句子中隐藏一些单词,并且模型将在训练期间尝试使用其他非掩码单词(即学习理解上下文)来预测它们。

    【讨论】:

    • 我推荐这个article,你的解释和文章很有帮助。
    • +1 是一个非常有用的例子。有了这个,我对 src_mark 就很清楚了。尽管如此, src_key_padding_mask 对我来说还是有点模糊。是忽略单个序列末尾的几个填充标记吗?
    猜你喜欢
    • 1970-01-01
    • 2021-02-24
    • 2021-03-28
    • 2021-07-24
    • 1970-01-01
    • 2019-06-27
    • 1970-01-01
    • 2021-12-07
    • 1970-01-01
    相关资源
    最近更新 更多