【问题标题】:How to improve LSTM model predictions and accuracy?如何提高 LSTM 模型的预测和准确性?
【发布时间】:2020-09-26 04:47:15
【问题描述】:

在使用 gensim 创建预嵌入层后,对于 4600 条记录,我的 val_accuracy 下降到 45%:-

model =  models.Sequential()
   
    model.add(Embedding(input_dim=MAX_NB_WORDS, output_dim=EMBEDDING_DIM, 
                         weights=[embedding_model],trainable=False,
                        input_length=seq_len,mask_zero=True))
    #model.add(SpatialDropout1D(0.2))
       
    
    #model.add(Embedding(vocabulary_size, 64))
    model.add(GRU(units=150, return_sequences=True))
    model.add(Dropout(0.4))
    model.add(LSTM(units=200,dropout=0.4))  
    #model.add(Dropout(0.8))
    #model.add(LSTM(100)) 
    #model.add(Dropout(0.4))
    #Bidirectional(tf.keras.layers.LSTM(embedding_dim))
    #model.add(LSTM(400,input_shape=(1117, 100),return_sequences=True))
    #model.add(Bidirectional(LSTM(128)))
    model.add(Dense(100, activation='relu'))
    #
    #model.add(Dropout(0.4))
    #model.add(Dense(200, activation='relu'))
    model.add(Dense(4, activation='softmax'))

    model.compile(loss='categorical_crossentropy', optimizer='rmsprop', 
                  metrics=['accuracy'])

型号:“sequential_4”

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_4 (Embedding)      (None, 50, 100)           2746300   
_________________________________________________________________
gru_4 (GRU)                  (None, 50, 150)           112950    
_________________________________________________________________
dropout_4 (Dropout)          (None, 50, 150)           0         
_________________________________________________________________
lstm_4 (LSTM)                (None, 200)               280800    
_________________________________________________________________
dense_7 (Dense)              (None, 100)               20100     
_________________________________________________________________
dense_8 (Dense)              (None, 4)                 404       
=================================================================
Total params: 3,160,554
Trainable params: 414,254
Non-trainable params: 2,746,300
_________________________________________________________________

完整代码位于 https://colab.research.google.com/drive/13N94kBKkHIX2TR5B_lETyuH1QTC5VuRf?usp=sharing

这对我很有帮助。因为我是深度学习的新手,我几乎尝试了所有我知道的东西。但现在都是空白。

【问题讨论】:

  • 你能详细说明这个问题吗?您从模型中得到什么样的预测?
  • 可能是您的模型过拟合。你检查过混淆矩阵吗?
  • @AbhinavGoyal 我在标记化后将 X 作为输入传递,结果我的模型预测句子将在 Y 中出现的类别。您可以在colab.research.google.com/drive/…查看代码
  • @Aka 是的,我已经检查了矩阵,但由于模型只给出了 30% 的准确率。矩阵信息对我没有帮助。看看代码colab.research.google.com/drive/…

标签: python tensorflow machine-learning keras lstm


【解决方案1】:

问题在于您的输入。您已用零填充输入序列,但尚未向模型提供此信息。因此,您的模型不会忽略零,这是它根本没有学习的原因。要解决此问题,请按如下方式更改嵌入层:

model.add(layers.Embedding(input_dim=vocab_size+1,
      output_dim=embedding_dim,
      mask_zero=True))

这将使您的模型忽略零填充并学习。用这个进行训练,我在 6 个 epoch 中得到了 100% 的训练准确率,尽管验证准确率不是那么好(大约 54%),这是预期的,因为你的训练数据只包含 32 个示例。更多关于嵌入层:https://keras.io/api/layers/core_layers/embedding/


由于您的数据集很小,模型很容易过度拟合训练数据,从而降低验证准确度。为了在一定程度上缓解这种情况,您可以尝试使用预训练的词嵌入,如 word2vec 或 GloVe,而不是训练自己的嵌入层。此外,尝试一些文本数据增强方法,例如使用模板创建人工数据或用同义词替换训练数据中的单词。您还可以尝试不同类型的层(例如用另一个 LSTM 替换 GRU),但我认为这在这里可能没有太大帮助,应该在尝试预训练嵌入和数据增强后考虑。

【讨论】:

  • 我已经尝试使用具有 1988 条记录的数据集。此外,增加了层中的单位并进行了一些修改,现在准确度达到了 88% 左右。我还能做些什么来将准确度提高到 95%+ ? colab.research.google.com/drive/…
  • 由于您的模型已经过拟合训练数据,我认为增加单元或隐藏层的数量可能会对性能产生不利影响。当您的模型不适合数据时,它会有所帮助。至于如何提高验证准确性,我已经用一些想法更新了我的答案。让我知道他们是否有帮助。
  • 我现在有 3640 条记录。对于嵌入矩阵,我尝试了使用“glove.6B.50d.txt”和“glove.6B.50d.txt”。还增加了层中的单位并进行了一些修改,但准确度仍然在 80% 左右。有模型文件夹大小限制,因为我不能将大小增加到 200mb 以上。因此,任何可以在不使用“Glove.5d”文件的情况下完成的建议。 colab.research.google.com/drive/….谢谢
  • 你必须明白机器学习不是魔法。为了提高模型的准确性,您必须尝试数据、预处理、模型和优化技术。即使在那之后,由于数据、计算资源或模型等的限制,您可能无法获得如此高的测试精度。我的建议是正确分析数据,设定现实的期望并根据该分析应用技术,而不是仅仅抛出您在任何随机 ML 模型中拥有的数据。无论如何,一切顺利进行进一步的实验:)
猜你喜欢
  • 1970-01-01
  • 2021-09-16
  • 2021-04-09
  • 2018-11-28
  • 2020-02-12
  • 2020-01-25
  • 2019-06-07
  • 2020-03-27
  • 2020-11-20
相关资源
最近更新 更多