【问题标题】:Simple Recurrent Neural Network input shape简单的循环神经网络输入形状
【发布时间】:2016-11-12 15:20:42
【问题描述】:

我正在尝试使用 keras 编写一个非常简单的 RNN 示例,但结果与预期不符。

我的 X_train 是一个长度为 6000 的重复列表,例如:1, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...

我将其格式化为:(6000, 1, 1)

我的 y_train 是一个长度为 6000 的重复列表,例如:1, 0.8, 0.6, 0, 0, 0, 1, 0.8, 0.6, 0, ...

我将其格式化为:(6000, 1)

在我的理解中,循环神经网络应该学会正确预测 0.8 和 0.6,因为它可以记住两个时间步前 X_train 中的 1。

我的模特:

model=Sequential()
model.add(SimpleRNN(input_dim=1, output_dim=50))
model.add(Dense(output_dim=1, activation = "sigmoid"))
model.compile(loss="mse", optimizer="rmsprop")
model.fit(X_train, y_train, nb_epoch=10, batch_size=32)

模型可以成功训练,损失最小~0.1015,但结果不如预期。

test case ---------------------------------------------  model result -------------expected result 

model.predict(np.array([[[1]]])) --------------------0.9825--------------------1

model.predict(np.array([[[1],[0]]])) ----------------0.2081--------------------0.8

model.predict(np.array([[[1],[0],[0]]])) ------------0.2778 -------------------0.6

model.predict(np.array([[[1],[0],[0],[0]]]))---------0.3186--------------------0

任何提示我在这里误解了什么?

【问题讨论】:

    标签: python neural-network keras recurrent-neural-network


    【解决方案1】:

    输入格式应该是三维的:三个分量分别代表样本大小、时间步数和输出维度

    一旦适当地重新格式化,RNN 确实可以很好地预测目标序列。

    np.random.seed(1337)
    
    sample_size = 256
    x_seed = [1, 0, 0, 0, 0, 0]
    y_seed = [1, 0.8, 0.6, 0, 0, 0]
    
    x_train = np.array([[x_seed] * sample_size]).reshape(sample_size,len(x_seed),1)
    y_train = np.array([[y_seed]*sample_size]).reshape(sample_size,len(y_seed),1)
    
    model=Sequential()
    model.add(SimpleRNN(input_dim  =  1, output_dim = 50, return_sequences = True))
    model.add(TimeDistributed(Dense(output_dim = 1, activation  =  "sigmoid")))
    model.compile(loss = "mse", optimizer = "rmsprop")
    model.fit(x_train, y_train, nb_epoch = 10, batch_size = 32)
    
    print(model.predict(np.array([[[1],[0],[0],[0],[0],[0]]])))
    #[[[ 0.87810659]
    #[ 0.80646527]
    #[ 0.61600274]
    #[ 0.01652312]
    #[ 0.00930419]
    #[ 0.01328572]]]
    

    【讨论】:

    • + 函数式,我第一次运行 RNN 模型 :)
    • 为 1337 提供 +1,保持这些传统的活力。
    猜你喜欢
    • 2020-12-28
    • 2016-08-18
    • 1970-01-01
    • 2017-07-30
    • 2017-08-31
    • 2020-05-26
    • 1970-01-01
    • 2017-12-05
    • 2020-03-16
    相关资源
    最近更新 更多