【Question Title】: Implement attention mechanism in the seq2seq Maluuba model
【Posted】: 2018-07-24 18:14:41
【Question】:

Hi, I am trying to add attention to the simple Maluuba/qgen-workshop seq2seq model, but I cannot figure out what the correct batch_size to pass to the initial state should be. Here is what I tried:

    # Attention
    # attention_states: [batch_size, max_time, num_units]
    attention_states = tf.transpose(encoder_outputs, [1, 0, 2])

    # Create an attention mechanism
    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
        encoder_cell.state_size, attention_states,
        memory_sequence_length=None)

    decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
        decoder_cell, attention_mechanism,
        attention_layer_size=encoder_cell.state_size)

    batch = next(training_data())
    batch = collapse_documents(batch)

    initial_state = decoder_cell.zero_state(batch["size"], tf.float32).clone(
        cell_state=encoder_state)

    decoder = seq2seq.BasicDecoder(decoder_cell, helper, initial_state,
                                   output_layer=projection)

It gives me this error:

    InvalidArgumentError (see above for traceback): assertion failed: [When applying AttentionWrapper attention_wrapper_1: Non-matching batch sizes between the memory (encoder output) and the query (decoder output).

  Are you using the BeamSearchDecoder?  You may need to tile your memory input via the tf.contrib.seq2seq.tile_batch function with argument multiple=beam_width.] [Condition x == y did not hold element-wise:] [x (decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/x:0) = ] [99] [y (LuongAttention/strided_slice_1:0) = ] [29]
     [[Node: decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert = Assert[T=[DT_STRING, DT_STRING, DT_STRING, DT_INT32, DT_STRING, DT_INT32], summarize=3, _device="/job:localhost/replica:0/task:0/cpu:0"](decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/All, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert/data_0, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert/data_1, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert/data_2, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/x, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Assert/Assert/data_4, decoder/while/BasicDecoderStep/decoder/attention_wrapper/assert_equal/Equal/Enter)]]
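The assertion is complaining that the decoder's query was built from a batch of 99 while the encoder memory came from a batch of 29. Conceptually, Luong attention scores each query against its own batch element's memory, so the leading (batch) dimensions must match. Here is a minimal NumPy sketch of that check, not TensorFlow's actual implementation; the shapes are hypothetical, chosen only to mirror the numbers in the error:

```python
import numpy as np

def luong_scores(query, memory):
    # query:  [batch, num_units]     -- decoder state at one step
    # memory: [batch, max_time, num_units] -- encoder outputs
    if query.shape[0] != memory.shape[0]:
        raise ValueError(
            "Non-matching batch sizes between the memory "
            f"({memory.shape[0]}) and the query ({query.shape[0]})")
    # Luong (dot) score: score[b, t] = query[b] . memory[b, t]
    return np.einsum("bu,btu->bt", query, memory)

memory = np.random.rand(29, 50, 64)   # encoder outputs from a batch of 29
query = np.random.rand(29, 64)        # decoder state from the same batch
scores = luong_scores(query, memory)  # shape (29, 50): one row per batch element
```

A query built from a batch of 99 against this memory would trip the same check the TensorFlow assertion performs.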

【Question Discussion】:

Tags: seq2seq


【Solution 1】:

We currently have _MAX_BATCH_SIZE = 128, but each batch has a different size, because we want to make sure that all the questions for a given story end up in the same batch. That is why every batch carries a 'size' key recording its actual size.

It sounds like you already know this, so I think the problem is something else. Maybe encoder_cell.state_size picked up the batch size from an earlier batch?
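To illustrate that hypothesis: a TensorFlow graph is built once, so any plain Python integer captured at build time (say, a batch size read from the first batch) gets frozen into the graph, whereas reading the size from the tensor itself (as `tf.shape(x)[0]` does) tracks each batch at run time. A small NumPy sketch of the two behaviors, with hypothetical helper names:

```python
import numpy as np

def build_frozen(batch_size):
    """Bakes a fixed batch size into the 'graph' at build time."""
    def run(x):
        if x.shape[0] != batch_size:
            raise ValueError("Non-matching batch sizes")
        return x.mean(axis=0)
    return run

def build_dynamic():
    """Reads the batch size from the input each call (like tf.shape(x)[0])."""
    def run(x):
        batch_size = x.shape[0]
        return x.sum(axis=0) / batch_size
    return run

frozen = build_frozen(29)             # built against an earlier batch of 29
dynamic = build_dynamic()

later_batch = np.random.rand(99, 64)  # a later batch with a different size
out = dynamic(later_batch)            # fine: size is read per batch
# frozen(later_batch) would raise, just like the AttentionWrapper assertion
```

If something in the model captures `batch["size"]` (or a state size derived from it) as a Python value during graph construction, every later batch of a different size will fail in exactly this way.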

【Discussion】:
