【问题标题】:Properly shaping LSTM input/output for text sequence prediction?为文本序列预测正确塑造 LSTM 输入/输出?
【发布时间】:2020-07-21 00:11:43
【问题描述】:

我正在尝试学习如何使用 LSTM 根据前一个字符预测下一个字符。我已经成功创建了一个函数,可以将一长串文本(在一次热编码之后)批处理成(batch_size, seq_len, one_hot_features) 的形状。 one_hot_features 只是文本中唯一字符的数量。

由于我的批处理功能为我提供了训练序列和“基本事实”序列,我需要 LSTM 网络的输出与“基本事实”批次的形状相同,以便我可以将输出和标签插入损失函数。

所以我的问题是:

  1. 如何正确定义我的网络架构以满足上述要求。
  2. 当我的数据通过forward 函数中的各个层时,我如何对其进行整形?

这是我的尝试,代码将运行,但网络输出的形状不正确,无法与“ground truth”序列批次进行比较:

class charNN(nn.Module):
    def __init__(self, vocab, hidden_size, n_layers, dropout=0.5):
        super().__init__()

        self.vocab_length = len(vocab)
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(self.vocab_length, hidden_size, n_layers, batch_first=True)
        self.fc   = nn.Linear(hidden_size, self.vocab_length)
        
    def forward(self, x, hidden=None):
        x, hidden = self.lstm(x, hidden)     # -> (n_batches, seq_len, hidden_size)
        x = x.reshape(-1, self.hidden_size)  # -> (n_batches * seq_len, hidden_size)
        x = self.fc(x)                       # -> (n_batches * seq_len, vocab_length)
        
        return x, hidden    

提前谢谢你。我已经为此苦苦挣扎了几天。如果需要,我也很乐意提供用于训练网络的代码。它不会抛出任何错误,只是告诉我输出大小和真实标签大小不匹配。

【问题讨论】:

  • 嗨,您是否考虑过添加额外的密集层甚至线性层以使 lstm 的输出达到所需的形状?然后,您还需要相应地调整 lstm 模块的隐藏层以实现目标形状。这两者的结合给了你想要的。您的模块固有的优化方法将确保这个额外的层不会影响训练的准确性。最佳
  • @smile 这是一个有趣的想法,但我很难想象你的意思。由于我仍在学习 LSTM(以及一般的 DL),我认为我更喜欢一种简单且易于推广到其他类型数据的解决方案。出于这个原因,我希望在开始添加新层之前找到最简单的架构。
  • @rockNwaves 我明白你的意思。我猜你已经对整个架构有了一个总体规划。但是您还需要确保体系结构的线性代数检查出来并且可以根据您想要的输出进行解释。因此,例如,您可能决定使用额外的 lstm 单元,或添加其他层或任何其他有助于提供所需输出的操作。您还可以观察到,在建议的答案中,这就是它正在尝试做的事情。负责架构,让它为你工作。最佳
  • @smile,本身不是一个计划,只是希望保持简单,直到我了解发生了什么。所以你是说如果不添加额外的层,我想做的事情可能无法实现?这本身就是一个答案......
  • @rockNwaves 我认为您还应该检查 hidden_​​size、n_layers 的变化如何影响输出的形状,或者 lstm 单元keras lstm 的参数变化如何改变输出大小?例如,将 stateful 设置为 true 将产生新的行为。如果这不起作用,那么附加层会有所帮助。从最少的附加层开始,因此很容易理解。然后朝着此处提供的答案前进。最佳

标签: python machine-learning neural-network pytorch lstm


【解决方案1】:

您的input_size: len(vocab) 并且您还为__init__ 方法提供了足够的数量信息。

我也是 lstm 的初学者。

答案 1:

据我说,您已经正确定义了您的架构。也许您可以添加以下内容:

  • token_size 变量:你的输出大小

  • encoderdecoder 结构。

  • 权重初始化方法

例如:

class charNN(nn.Module):
    def __init__(self, vocab, hidden_size, n_layers, dropout=0.5, token_size):
        super(charNN, self).__init__()
        self.token_size = token_size
        self.drop = nn.Dropout(dropout)
        self.vocab_length = len(vocab)
        self.hidden_size = hidden_size
        self.encoder = nn.Embedding(token_size, vocab_length)
        self.lstm = nn.LSTM(self.vocab_length,   
                            hidden_size, 
                            n_layers, 
                            batch_first=True, 
                            dropout=dropout)
        self.fc   = nn.Linear(hidden_size,  
                              self.vocab_length)
        self.decoder = nn.Linear(hidden_size, token_size)
        self.init_weights()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def init_weights(self):
        init_range = 0.1
        uniform_(self.encoder.weight, -init_range, init_range)
        zeros_(self.decoder.weight)
        uniform_(self.decoder.weight, -init_range, init_range)
        

    def init_hidden(self, batch_size):
        weight = next(self.parameters())
        return (weight.new_zeros(self.num_layers, batch_size,
                                 self.hidden_size),
                weight.new_zeros(self.num_layers, batch_size,
                                 self.hidden_size))

答案 2:

您可以将以下结构用于转发功能:

    def forward(self, input_, hidden_):
        embedded = self.drop(self.encoder(input_))
        output, hidden_ = self.rnn(input_, hidden_)
        output = self.drop(output)
        decoded = self.decoder(output)
        decoded = decoded.view(-1, self.token_size)
        output = log_softmax(decoded, dim=1)
        return output, hidden_

希望对你有帮助,祝你好运!

【讨论】:

  • 嗨!变量token_size 与我已经定义为vocab_length 的变量相同。我也对增加额外复杂性的解决方案不感兴趣,因为我仍在学习。如果您能按照说明回答问题,将不胜感激。感谢您的宝贵时间!
猜你喜欢
  • 1970-01-01
  • 2021-09-30
  • 1970-01-01
  • 2018-02-20
  • 1970-01-01
  • 1970-01-01
  • 2018-06-25
  • 2019-11-19
  • 1970-01-01
相关资源
最近更新 更多