【问题标题】:What is equivalent to pytorch lstm num_layers?什么相当于 pytorch lstm num_layers?
【发布时间】:2021-11-11 12:29:05
【问题描述】:

我是 PyTorch 的初学者。从 lstm description,我了解到我可以通过以下方式创建具有 3 层的堆叠 lstm:

layer = torch.nn.LSTM(128, 512, num_layers=3)

那么在forward函数中,我可以这样做:

def forward(x, state):
    x, state = layer(x, state)
    return x, (state[0].detach(), state[1].detach())

我可以将state 逐批传递。
但是如果我创建了 3 个 lstm 层,如果我想自己实现相同的堆叠层,那相当于什么?

layer1 = torch.nn.LSTM(128, 512, num_layers=1)
layer2 = torch.nn.LSTM(128, 512, num_layers=1)
layer3 = torch.nn.LSTM(128, 512, num_layers=1)

在这种情况下,应该进入forward函数并获取返回的state
我还尝试查看 pytorch lstm 的 source code,但在 forward 函数中它调用了 _VF 模块,我找不到它的定义位置。

【问题讨论】:

    标签: python pytorch lstm


    【解决方案1】:

    如果您将state 定义为 3 层状态的列表,那么

    def forward(x, state):
        x, s0 = layer1(x, state[0])
        x, s1 = layer2(x, state[1])
        x, s2 = layer3(x, state[2])
        return x, [s0.detach(), s1.detach(), s2.detach()]
    

    【讨论】:

      猜你喜欢
      • 2018-08-19
      • 1970-01-01
      • 2017-12-07
      • 2019-09-18
      • 2020-05-05
      • 2019-12-21
      • 2018-06-21
      • 2019-11-27
      • 2021-12-04
      相关资源
      最近更新 更多