【问题标题】:Tensorflow seq2seq Model, loss value is good, but prediction is falseTensorflow seq2seq 模型,损失值不错,但预测错误
【发布时间】:2020-04-12 09:51:25
【问题描述】:

我正在尝试训练 Seq2Seq 模型。它应该将句子从 source_vocabulary 翻译成 target_vocabulary 中的句子。

损失值为 0.28,但网络不预测目标词汇表中的单词。相反,网络的预测是负值。我不确定,如果代码中的某些内容是错误的,或者词汇量是否太大,或者我没有接受足够的训练。我用大约 270 000 个句子的数据集的一部分进行训练。即使损失值降低,我也不知道网络是否在学习什么。

def encDecEmb():
    batch_size = 32
    seq_length = 40
    vocab_size = 289415
    epochs = 10
    embedding_size = 300
    hidden_units = 20
    learning_rate = 0.001

    #shape = batch_size, seq_length
    encoder_inputs = tf.placeholder(
        tf.int32, shape=(None, None), name='encoder_inputs')
    decoder_inputs = tf.placeholder(
        tf.int32, shape=(None, None), name='decoder_inputs')

    sequence_length = tf.placeholder(tf.int32, [None], name='sequence_length')

    # Embedding
    embedding = tf.get_variable("embedding", [vocab_size, embedding_size])
    encoder_embedding = tf.nn.embedding_lookup(embedding, encoder_inputs)
    decoder_embedding = tf.nn.embedding_lookup(embedding, decoder_inputs)

    # Encoder
    encoder_cell = tf.contrib.rnn.LSTMCell(hidden_units)
    encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
        encoder_cell, encoder_embedding, dtype=tf.float32)

    projection_layer = Dense(vocab_size, use_bias=False)
    helper = tf.contrib.seq2seq.TrainingHelper(decoder_embedding,
                                               sequence_length=sequence_length)

    # Decoder
    decoder_cell = tf.contrib.rnn.LSTMCell(hidden_units)
    decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=decoder_cell, initial_state=encoder_final_state, helper=helper,
        output_layer=projection_layer)

    decoder_outputs, _final_state, _final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
        decoder)

    logits = decoder_outputs.rnn_output
    training_logits = tf.identity(decoder_outputs.rnn_output, name='logits')
    target_labels = tf.placeholder(tf.int32, shape=(batch_size, seq_length))

    weight_mask = tf.sequence_mask([i for i in range(
        batch_size)], seq_length, dtype=tf.float32, name="weight_mask")

    # loss
    loss = tf.contrib.seq2seq.sequence_loss(
        logits=training_logits, targets=decoder_inputs, weights=weight_mask)

    #AdamOptimizer, Gradientclipping
    optimizer = tf.train.AdamOptimizer(learning_rate)
    gradients = optimizer.compute_gradients(loss)
    capped_gradients = [(tf.clip_by_value(grad, -1., 1.), var)
                        for grad, var in gradients if grad is not None]
    train_opt = optimizer.apply_gradients(capped_gradients)

    # read files
    x = readCSV_to_int("./xTest.csv")
    y = readCSV_to_int("./yTest.csv")

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    for epoch in range(epochs):
        for batch, (inputs, targets) in enumerate(get_batches(x, y, batch_size)):
            _, loss_value = sess.run([train_opt, loss],
                                     feed_dict={encoder_inputs: inputs, decoder_inputs: targets, target_labels: targets,
                                                sequence_length: [len(inputs[0])] * batch_size})
            print('Epoch{:>3}  Batch {:>4}/{}  Loss {:>6.4f}'.format(epoch, batch, (len(x) // batch_size),
                                                                     loss_value))

    saver.save(sess, './model_on_testset')
    print("Model Trained and Saved")

【问题讨论】:

    标签: lstm tensorflow2.0 loss-function seq2seq encoder-decoder


    【解决方案1】:

    没关系。 logits 是负数,它们是通过 softmax 标准化之前网络的输出,即,这些数字在标准化之前被取幂(因此变成小的正数)。

    您选择的架构:普通的序列到序列模型并不是最容易训练的模型。鉴于您只有有限数量的训练数据(270k),我会使用:

    1. 双向编码器;和
    2. 具有注意力机制的解码器。

    Attention 简化了从解码器(发生实际损失计算的地方)到解码器的梯度流,因为梯度不仅通过编码器最终状态流动,而且通过对所有编码器状态的关注。

    还有其他影响性能的因素:例如你如何分割你的数据(无论你使用单词还是子词)以及你使用的词汇量有多大。

    无论如何,确定模型是否正在学习的最佳方法是尝试从模型中解码目标句子。

    【讨论】:

      猜你喜欢
      • 2020-11-08
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2018-09-03
      • 2021-10-17
      • 2019-05-24
      • 2019-04-17
      • 1970-01-01
      相关资源
      最近更新 更多