【问题标题】:Extracting hidden representations for each token - PyTorch LSTM提取每个标记的隐藏表示 - PyTorch LSTM
【发布时间】:2021-07-22 09:23:37
【问题描述】:

我目前正在从事一个涉及循环神经网络的 NLP 项目。我按照教程here 使用 PyTorch 实现了 LSTM。

对于我的项目,我需要为输入文本的每个标记提取隐藏表示。我认为最简单的方法是使用批量大小和序列长度 1 进行测试,但是当我这样做时,损失会比训练阶段大几个数量级(在训练期间,我使用了 64 的批量大小和序列长度35)。

有没有其他方法可以轻松访问这些单词级隐藏表示?谢谢。

【问题讨论】:

    标签: nlp pytorch lstm recurrent-neural-network


    【解决方案1】:

    是的,只要是单层 LSTM,nn.LSTM 就可以做到这一点。如果您查看文档 (here),对于 LSTM 的输出,您可以看到它输出一个张量和一个张量元组。元组包含最后一个序列步骤的隐藏和单元格。每个维度对输出的含义取决于您如何初始化网络。第一个或第二个维度是批量维度,其余的是您想要的词嵌入序列。

    如果你使用打包序列作为输入,那就有点不同了。

    【讨论】:

      猜你喜欢
      • 2018-09-21
      • 1970-01-01
      • 1970-01-01
      • 2014-12-01
      • 2018-10-05
      • 2018-06-26
      • 1970-01-01
      • 2019-07-31
      • 2015-01-27
      相关资源
      最近更新 更多