【问题标题】:How to build a attention model with keras?如何用 keras 建立注意力模型?
【发布时间】:2019-11-18 16:21:22
【问题描述】:

我正在尝试理解注意力模型并自己构建一个。经过多次搜索,我遇到了this website,它有一个用 keras 编码的 atteniton 模型,而且看起来也很简单。但是当我尝试在我的机器中构建相同的模型时,它给出了多个参数错误。该错误是由于类Attention 中传递的参数不匹配造成的。在网站的注意力类中,它要求一个参数,但它用两个参数启动注意力对象。

import tensorflow as tf

max_len = 200
rnn_cell_size = 128
vocab_size=250

class Attention(tf.keras.Model):
    def __init__(self, units):
        super(Attention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)
    def call(self, features, hidden):
        hidden_with_time_axis = tf.expand_dims(hidden, 1)
        score = tf.nn.tanh(self.W1(features) + self.W2(hidden_with_time_axis))
        attention_weights = tf.nn.softmax(self.V(score), axis=1)
        context_vector = attention_weights * features
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights

sequence_input = tf.keras.layers.Input(shape=(max_len,), dtype='int32')

embedded_sequences = tf.keras.layers.Embedding(vocab_size, 128, input_length=max_len)(sequence_input)

lstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM
                                     (rnn_cell_size,
                                      dropout=0.3,
                                      return_sequences=True,
                                      return_state=True,
                                      recurrent_activation='relu',
                                      recurrent_initializer='glorot_uniform'), name="bi_lstm_0")(embedded_sequences)

lstm, forward_h, forward_c, backward_h, backward_c = tf.keras.layers.Bidirectional \
    (tf.keras.layers.LSTM
     (rnn_cell_size,
      dropout=0.2,
      return_sequences=True,
      return_state=True,
      recurrent_activation='relu',
      recurrent_initializer='glorot_uniform'))(lstm)

state_h = tf.keras.layers.Concatenate()([forward_h, backward_h])
state_c = tf.keras.layers.Concatenate()([forward_c, backward_c])

#  PROBLEM IN THIS LINE
context_vector, attention_weights = Attention(lstm, state_h)

output = keras.layers.Dense(1, activation='sigmoid')(context_vector)

model = keras.Model(inputs=sequence_input, outputs=output)

# summarize layers
print(model.summary())

我怎样才能使这个模型工作?

【问题讨论】:

标签: python tensorflow keras deep-learning attention-model


【解决方案1】:

你初始化attention layer和传递参数的方式有问题。你应该在这个地方指定attention layer单元的个数并修改传入参数的方式:

context_vector, attention_weights = Attention(32)(lstm, state_h)

结果:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 200)          0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 200, 128)     32000       input_1[0][0]                    
__________________________________________________________________________________________________
bi_lstm_0 (Bidirectional)       [(None, 200, 256), ( 263168      embedding[0][0]                  
__________________________________________________________________________________________________
bidirectional (Bidirectional)   [(None, 200, 256), ( 394240      bi_lstm_0[0][0]                  
                                                                 bi_lstm_0[0][1]                  
                                                                 bi_lstm_0[0][2]                  
                                                                 bi_lstm_0[0][3]                  
                                                                 bi_lstm_0[0][4]                  
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 256)          0           bidirectional[0][1]              
                                                                 bidirectional[0][3]              
__________________________________________________________________________________________________
attention (Attention)           [(None, 256), (None, 16481       bidirectional[0][0]              
                                                                 concatenate[0][0]                
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 1)            257         attention[0][0]                  
==================================================================================================
Total params: 706,146
Trainable params: 706,146
Non-trainable params: 0
__________________________________________________________________________________________________
None

【讨论】:

    【解决方案2】:

    注意力层现在是 Tensorflow(2.1) 的 Keras API 的一部分。但它输出的张量与您的“查询”张量大小相同。

    这是如何使用 Luong 式的注意力:

    query_attention = tf.keras.layers.Attention()([query, value])
    

    还有 Bahdanau 式的关注:

    query_attention = tf.keras.layers.AdditiveAttention()([query, value])
    

    改编版:

    attention_weights = tf.keras.layers.Attention()([lstm, state_h])

    查看原网站了解更多信息:https://www.tensorflow.org/api_docs/python/tf/keras/layers/Attention https://www.tensorflow.org/api_docs/python/tf/keras/layers/AdditiveAttention

    【讨论】:

    • 您能否根据此特定 OP 的问题澄清“查询”和“价值”? OP 想要将“lstm”和“state_h”传递给注意力层。
    • queryvalue 应该是什么?有什么例子吗?
    • @Yahya 他们需要是时序数据格式 [batch, time, feature] 的 TensorFlow 张量。我希望这是你所要求的。
    【解决方案3】:

    为了回答 Arman 的特定查询 - 这些库使用 2018 年后的查询、值和键语义。要将语义映射回 Bahdanau 或 Luong 的论文,您可以将“查询”视为最后一个解码器隐藏状态。 “值”将是编码器输出的集合——编码器的所有隐藏状态。 “查询”“参与”所有“值”。

    无论您使用的是哪个版本的代码或库,请始终注意“查询”将在时间轴上展开,以便为随后的添加做好准备。这个值(正在扩展)将始终是 RNN 的最后一个隐藏状态。另一个值将始终是需要关注的值 - 编码器端的所有隐藏状态。无论您使用的是什么库或代码,都可以对代码进行这种简单的检查以确定“查询”和“值”映射到什么。

    您可以参考https://towardsdatascience.com/create-your-own-custom-attention-layer-understand-all-flavours-2201b5e8be9e,用不到6行代码编写自己的自定义注意力层

    【讨论】:

      猜你喜欢
      • 2016-09-11
      • 1970-01-01
      • 2019-02-08
      • 2018-04-20
      • 2018-06-21
      • 2021-02-07
      • 1970-01-01
      • 1970-01-01
      • 2020-08-29
      相关资源
      最近更新 更多