【问题标题】:Does batch_first affect hidden tensors in Pytorch LSTMs?batch_first 会影响 Pytorch LSTM 中的隐藏张量吗?
【发布时间】:2018-10-05 08:16:18
【问题描述】:

batch_first 会影响 Pytorch LSTM 中的隐藏张量吗?

即如果batch_first参数为真, 隐藏状态会不会是(numlayer*direction,num_batch,encoding_dim)(num_batch,numlayer*direction,encoding_dim)

我都测试过,都没有错误。

【问题讨论】:

    标签: lstm pytorch


    【解决方案1】:

    来自docs

    batch_first - 如果为 True,则输入和输出张量提供为 (batch, seq, feature)

    所以,是的,如果您的输入是批处理优先,那么输出也将是批处理优先。

    【讨论】:

    • 这里的问题是隐藏状态是否相同。还是谢谢你。
    【解决方案2】:

    前段时间我也在思考同样的问题。就像laydog概述的那样,在文档中它说

    batch_first - 如果为 True,则提供输入和输出张量 as (batch, seq, feature)

    据我了解,我们谈论的是隐藏/单元状态元组,而不是实际的输入和输出。

    对我来说,这似乎很明显不会影响他们提到的隐藏状态:

    (批次、序列、特征)

    这显然是指输入和输出,而不是由两个具有形状的元组组成的状态元组:

    (num_layers * num_directions, batch, hidden_​​size)

    所以我很确定隐藏状态和单元状态不会受此影响,更改隐藏状态元组的顺序对我来说也没有意义。

    希望这会有所帮助。

    【讨论】:

      【解决方案3】:

      我们来看一个例子

      batch_size = 8
      sequence_length = 10
      input_dim = 64
      lstm = nn.LSTM(input_size=input_dim, hidden_size=32, num_layers=1, batch_first=True, bidirectional=False)
      lstm_input = torch.randn(batch_size, sequence_length, input_dim)
      output, (hidden, cell) = lstm(lstm_input)
      # >>> output.shape, hidden.shape, cell.shape
      # (torch.Size([8, 10, 32]), torch.Size([1, 8, 32]), torch.Size([1, 8, 32]))
      

      所以我们可以看到output 变量是批处理优先的,但隐藏和单元状态不是。

      当您想要输入初始隐藏或单元状态时,事情会变得有点复杂

      您可能期望这会起作用:

      initial_cell = torch.randn(batch_size, 1, 32)
      initial_hidden = torch.randn(batch_size, 1, 32)
      # WARNING, THIS DOES NOT WORK, IT IS AN EXAMPLE!
      # >>> output, (hidden, cell) = lstm(lstm_input, (initial_hidden, initial_cell))
      # RuntimeError: Expected hidden[0] size (1, 8, 32), got [8, 1, 32]
      

      这意味着输入隐藏和单元格状态必须是(sequence length, batch, hidden dim)格式。

      initial_cell = torch.randn(1, batch_size, 32)
      initial_hidden = torch.randn(1, batch_size, 32)
      output, (hidden, cell) = lstm(lstm_input, (initial_hidden, initial_cell))
      # >>> output.shape, hidden.shape, cell.shape
      # (torch.Size([8, 10, 32]), torch.Size([1, 8, 32]), torch.Size([1, 8, 32]))
      

      因此我们可以看到,无论batch_first 是什么,隐藏和单元格状态始终采用(seq, batch, dim) 格式,无论它是LSTM 单元格的输入参数还是输出参数。

      GRU 的隐藏状态也是如此。

      【讨论】:

        猜你喜欢
        • 2021-11-16
        • 2020-09-28
        • 2018-09-21
        • 2011-09-01
        • 1970-01-01
        • 1970-01-01
        • 2018-12-29
        • 2020-05-26
        • 2015-09-24
        相关资源
        最近更新 更多