【发布时间】:2018-10-04 13:23:19
【问题描述】:
这是 pytorch lstmcell 的示例:
rnn = nn.LSTMCell(10, 20)
input = torch.randn(6, 3, 10)
hx = torch.randn(3, 20)
cx = torch.randn(3, 20)
output = []
hx, cx = rnn(input[0], (hx, cx))
output.append(hx)
不确定如何将其转换为 keras lstm/lstmcell
【问题讨论】:
-
到目前为止你尝试过什么?你看过 keras 文档吗?
-
是的,原码:self.att_lstm = nn.LSTMCell(1536, 512) h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0] ])) 我所尝试的,输入 = Input(shape=(10, 1536)) lstm, h_att, c_att = LSTM(units=512, input_shape=(10,1536), name='core.att_lstm', return_state= True)(inputs) 所以我不确定它是否正确
-
请,不要在 cmets 中张贴代码 - 它实际上是不可读的!改为编辑和更新您的帖子!
标签: keras