【问题标题】:Keras simple RNN implementationKeras 简单的 RNN 实现
【发布时间】:2018-02-25 15:08:39
【问题描述】:

我在尝试编译具有一个循环层的网络时发现了问题。第一层的维度似乎存在一些问题,因此我对 RNN 层在 Keras 中的工作方式的理解存在问题。

我的代码示例是:

model.add(Dense(8,
                input_dim = 2,
                activation = "tanh",
                use_bias = False))
model.add(SimpleRNN(2,
                    activation = "tanh",
                    use_bias = False))
model.add(Dense(1,
                activation = "tanh",
                use_bias = False))

错误是

ValueError: Input 0 is incompatible with layer simple_rnn_1: expected ndim=3, found ndim=2

无论input_dim 值如何,都会返回此错误。我错过了什么?

【问题讨论】:

    标签: machine-learning neural-network keras recurrent-neural-network rnn


    【解决方案1】:

    该消息的意思是:进入 rnn 的输入有 2 个维度,但 rnn 层需要 3 个维度。

    对于 RNN 层,您需要形状类似于 (BatchSize, TimeSteps, FeaturesPerStep) 的输入。这些是预期的 3 个维度。

    Dense 层(在 keras 2 中)可以使用 2 维或 3 维。我们可以看到您正在使用 2,因为您传递了 input_dim 而不是 input_shape=(Steps,Features)

    有很多可能的方法可以解决这个问题,但最有意义和最合乎逻辑的情况是您的输入数据是具有时间步长的序列。

    解决方案 1 - 您的训练数据是一个序列:

    如果您的训练数据是一个序列,您可以将其塑造成(NumberOfSamples, TimeSteps, Features) 并将其传递给您的模型。确保在第一层使用input_shape=(TimeSteps,Features) 而不是input_dim

    解决方案 2 - 重塑第一个密集层的输出,使其具有额外的维度:

    model.add(Reshape((TimeSteps,Features)))
    

    确保乘积 TimeSteps*Features 等于 8,即第一个密集层的输出。

    【讨论】:

    • 太棒了,谢谢,还有一个问题。第一个解决方案工作得很好,但是如果我想要无限的时间步长(理论上的问题,我知道,无限的时间步长是愚蠢的)怎么办?然后,我必须使用您的第二个解决方案来重塑第一层的输出。但是,我用记忆异或序列做了简单的测试,当我洗牌输出时,网络没有像我预期的那样对它做出反应。更好的说法是,它返回的输出与 shuffle 之前相同。重塑究竟如何影响循环层的工作(与第一个解决方案相比)?
    • 重塑只是获取数据(任何数据),这只不过是一个分割成段的数字序列。假设您有 300 个元素。当你像 (30,10,1) 那样重塑它们时,你只是以不同的方式分离了这 300 个元素。因此,如果您出于序列目的进行重塑,您必须了解您想要实现的目标以及数据的格式,这样您才能以一种重要的方式对其进行重塑。
    • 对于您的无限序列,您可能应该使用只有 1 个样本 (BatchSize=1, TimeSteps, Features) 的输入,并用 stateful=True 标记您的循环层。这意味着这些层将保持它们的记忆,并且下一批将被视为在一个序列中继续前一批。在这种情况下,当您决定一个序列结束并开始提供另一个序列时,您必须手动“擦除内存”(称为“重置状态”)。
    • 丹尼尔·穆勒 谢谢。
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多