【问题标题】:How to convert pytorch lstmcell to keras lstm or lstmcell如何将 pytorch lstmcell 转换为 keras lstm 或 lstmcell
【发布时间】:2018-10-04 13:23:19
【问题描述】:

这是 pytorch lstmcell 的示例:

rnn = nn.LSTMCell(10, 20)
input = torch.randn(6, 3, 10)
hx = torch.randn(3, 20)
cx = torch.randn(3, 20)
output = []
hx, cx = rnn(input[0], (hx, cx))
output.append(hx)

不确定如何将其转换为 keras lstm/lstmcell

【问题讨论】:

  • 到目前为止你尝试过什么?你看过 keras 文档吗?
  • 是的,原码:self.att_lstm = nn.LSTMCell(1536, 512) h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0] ])) 我所尝试的,输入 = Input(shape=(10, 1536)) lstm, h_att, c_att = LSTM(units=512, input_shape=(10,1536), name='core.att_lstm', return_state= True)(inputs) 所以我不确定它是否正确
  • 请,不要在 cmets 中张贴代码 - 它实际上是不可读的!改为编辑和更新您的帖子!

标签: keras


【解决方案1】:

原始 Pytorch 代码:
self.att_lstm = nn.LSTMCell(1536, 512)
h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0]))
state[0][0], state[1][0] 是张量(10,512) 我在 keras 中尝试过的内容: inputs = Input(shape=(10, 1536))
lstm, h_att, c_att = LSTM(units=512, input_shape=(10,1536), name='core.att_lstm', return_state=True)(inputs)
所以我不确定它是否正确。

【讨论】:

    猜你喜欢
    • 2019-09-28
    • 1970-01-01
    • 2022-06-16
    • 2020-02-09
    • 2021-07-01
    • 2018-06-29
    • 1970-01-01
    • 1970-01-01
    • 2021-07-10
    相关资源
    最近更新 更多