【Question Title】: How to apply a different dense layer to each row of a matrix in Keras
【Posted】: 2026-01-07 18:50:02
【Question Description】:

The output of my previous layer has shape (None, 30, 600). I want to multiply each row of this matrix by a *different* (600, 600) matrix, or equivalently, multiply the matrix by a 3D weight tensor. This can be achieved by applying a different dense layer to each row. I tried the TimeDistributed wrapper, but that applies the *same* dense layer to every row. I also tried a Lambda layer like this:

Lambda(lambda x: tf.stack(x, axis=1))(
    Lambda(lambda x: [Dense(600)(each) for each in tf.unstack(x, axis=1)])(prev_layer_output)
)

This seemed to solve the problem, and I was able to train the model correctly. But I noticed that model.summary() does not show these dense layers, and they are not reflected in the total trainable parameter count either. Moreover, when I load the model I cannot restore their weights, so the whole training run is wasted. How can I fix this? How do I apply a different dense layer to each row of a matrix?

【Question Comments】:

    Tags: deep-learning keras keras-layer


    【Solution 1】:

    Instead of wrapping everything into a single Lambda layer, you can use several layers:

    import numpy as np
    import tensorflow as tf
    from keras.layers import Input, Dense, Lambda
    from keras.models import Model

    x = Input((30, 600))
    # Split the (30, 600) input into a list of 30 (600,) tensors.
    unstacked = Lambda(lambda x: tf.unstack(x, axis=1))(x)
    # A separate Dense(600) for each row, so every row gets its own weights.
    dense_outputs = [Dense(600)(t) for t in unstacked]
    # Stack the 30 outputs back into a (30, 600) tensor.
    merged = Lambda(lambda x: tf.stack(x, axis=1))(dense_outputs)
    model = Model(x, merged)
    

    Now you can see the 30 Dense(600) layers in model.summary():

    __________________________________________________________________________________________________
    Layer (type)                    Output Shape         Param #     Connected to
    ==================================================================================================
    input_1 (InputLayer)            (None, 30, 600)      0
    __________________________________________________________________________________________________
    lambda_1 (Lambda)               [(None, 600), (None, 0           input_1[0][0]
    __________________________________________________________________________________________________
    dense_1 (Dense)                 (None, 600)          360600      lambda_1[0][0]
    __________________________________________________________________________________________________
    dense_2 (Dense)                 (None, 600)          360600      lambda_1[0][1]
    __________________________________________________________________________________________________
    dense_3 (Dense)                 (None, 600)          360600      lambda_1[0][2]
    __________________________________________________________________________________________________
    dense_4 (Dense)                 (None, 600)          360600      lambda_1[0][3]
    __________________________________________________________________________________________________
    dense_5 (Dense)                 (None, 600)          360600      lambda_1[0][4]
    __________________________________________________________________________________________________
    dense_6 (Dense)                 (None, 600)          360600      lambda_1[0][5]
    __________________________________________________________________________________________________
    dense_7 (Dense)                 (None, 600)          360600      lambda_1[0][6]
    __________________________________________________________________________________________________
    dense_8 (Dense)                 (None, 600)          360600      lambda_1[0][7]
    __________________________________________________________________________________________________
    dense_9 (Dense)                 (None, 600)          360600      lambda_1[0][8]
    __________________________________________________________________________________________________
    dense_10 (Dense)                (None, 600)          360600      lambda_1[0][9]
    __________________________________________________________________________________________________
    dense_11 (Dense)                (None, 600)          360600      lambda_1[0][10]
    __________________________________________________________________________________________________
    dense_12 (Dense)                (None, 600)          360600      lambda_1[0][11]
    __________________________________________________________________________________________________
    dense_13 (Dense)                (None, 600)          360600      lambda_1[0][12]
    __________________________________________________________________________________________________
    dense_14 (Dense)                (None, 600)          360600      lambda_1[0][13]
    __________________________________________________________________________________________________
    dense_15 (Dense)                (None, 600)          360600      lambda_1[0][14]
    __________________________________________________________________________________________________
    dense_16 (Dense)                (None, 600)          360600      lambda_1[0][15]
    __________________________________________________________________________________________________
    dense_17 (Dense)                (None, 600)          360600      lambda_1[0][16]
    __________________________________________________________________________________________________
    dense_18 (Dense)                (None, 600)          360600      lambda_1[0][17]
    __________________________________________________________________________________________________
    dense_19 (Dense)                (None, 600)          360600      lambda_1[0][18]
    __________________________________________________________________________________________________
    dense_20 (Dense)                (None, 600)          360600      lambda_1[0][19]
    __________________________________________________________________________________________________
    dense_21 (Dense)                (None, 600)          360600      lambda_1[0][20]
    __________________________________________________________________________________________________
    dense_22 (Dense)                (None, 600)          360600      lambda_1[0][21]
    __________________________________________________________________________________________________
    dense_23 (Dense)                (None, 600)          360600      lambda_1[0][22]
    __________________________________________________________________________________________________
    dense_24 (Dense)                (None, 600)          360600      lambda_1[0][23]
    __________________________________________________________________________________________________
    dense_25 (Dense)                (None, 600)          360600      lambda_1[0][24]
    __________________________________________________________________________________________________
    dense_26 (Dense)                (None, 600)          360600      lambda_1[0][25]
    __________________________________________________________________________________________________
    dense_27 (Dense)                (None, 600)          360600      lambda_1[0][26]
    __________________________________________________________________________________________________
    dense_28 (Dense)                (None, 600)          360600      lambda_1[0][27]
    __________________________________________________________________________________________________
    dense_29 (Dense)                (None, 600)          360600      lambda_1[0][28]
    __________________________________________________________________________________________________
    dense_30 (Dense)                (None, 600)          360600      lambda_1[0][29]
    __________________________________________________________________________________________________
    lambda_2 (Lambda)               (None, 30, 600)      0           dense_1[0][0]
                                                                     dense_2[0][0]
                                                                     dense_3[0][0]
                                                                     dense_4[0][0]
                                                                     dense_5[0][0]
                                                                     dense_6[0][0]
                                                                     dense_7[0][0]
                                                                     dense_8[0][0]
                                                                     dense_9[0][0]
                                                                     dense_10[0][0]
                                                                     dense_11[0][0]
                                                                     dense_12[0][0]
                                                                     dense_13[0][0]
                                                                     dense_14[0][0]
                                                                     dense_15[0][0]
                                                                     dense_16[0][0]
                                                                     dense_17[0][0]
                                                                     dense_18[0][0]
                                                                     dense_19[0][0]
                                                                     dense_20[0][0]
                                                                     dense_21[0][0]
                                                                     dense_22[0][0]
                                                                     dense_23[0][0]
                                                                     dense_24[0][0]
                                                                     dense_25[0][0]
                                                                     dense_26[0][0]
                                                                     dense_27[0][0]
                                                                     dense_28[0][0]
                                                                     dense_29[0][0]
                                                                     dense_30[0][0]
    ==================================================================================================
    Total params: 10,818,000
    Trainable params: 10,818,000
    Non-trainable params: 0
    __________________________________________________________________________________________________
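The parameter count is consistent: each Dense(600) applied to a 600-dimensional row carries 600 × 600 weights plus 600 biases, and there are 30 such layers:

```python
# Parameters of one Dense(600) on a 600-dim input: weights + biases.
per_dense = 600 * 600 + 600
total = 30 * per_dense
print(per_dense, total)  # 360600 10818000
```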
    

    Edit: to verify that this model is learning:

    model.compile(loss='mse', optimizer='adam')
    model.fit(np.random.rand(100, 30, 600), np.random.rand(100, 30, 600), epochs=10)
    

    You should see the loss decreasing:

    Epoch 1/10
    100/100 [==============================] - 1s 15ms/step - loss: 0.4725
    Epoch 2/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.2211
    Epoch 3/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.2405
    Epoch 4/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.2013
    Epoch 5/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.1771
    Epoch 6/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.1676
    Epoch 7/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.1568
    Epoch 8/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.1473
    Epoch 9/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.1400
    Epoch 10/10
    100/100 [==============================] - 0s 1ms/step - loss: 0.1343
    

    You can also verify that the weights really do get updated by comparing their values before and after fitting:

    w0 = model.get_weights()
    model.fit(np.random.rand(100, 30, 600), np.random.rand(100, 30, 600), epochs=10)
    w1 = model.get_weights()

    # True iff every weight tensor changed during fitting
    print(not any(np.allclose(x0, x1) for x0, x1 in zip(w0, w1)))
    # => True
    
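If thirty separate layers feel unwieldy, the same per-row mapping can also be packed into one custom layer holding a single (30, 600, 600) kernel; because the weights live in tracked variables, they appear in model.summary() and survive saving and loading. This is a sketch assuming TF 2-style Keras, not the answer's original code, and the class name `RowwiseDense` is made up:

```python
import tensorflow as tf
from tensorflow import keras

class RowwiseDense(keras.layers.Layer):
    """Applies a different (in_dim, units) dense map to each row of the input."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        rows, in_dim = input_shape[1], input_shape[2]
        # One (in_dim, units) weight matrix per row, stored in a single
        # variable so it is counted in summary() and saved with the model.
        self.kernel = self.add_weight(
            shape=(rows, in_dim, self.units), initializer='glorot_uniform',
            trainable=True, name='kernel')
        self.bias = self.add_weight(
            shape=(rows, self.units), initializer='zeros',
            trainable=True, name='bias')

    def call(self, x):
        # out[b, i, k] = sum_j x[b, i, j] * kernel[i, j, k] + bias[i, k]
        return tf.einsum('bij,ijk->bik', x, self.kernel) + self.bias

inp = keras.Input((30, 600))
out = RowwiseDense(600)(inp)
model = keras.Model(inp, out)
```

The parameter count matches the 30-layer version: 30 × (600 × 600 + 600) = 10,818,000.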

    【Comments】:

    • I tried this. It lets the model compile, but the weight updates for the dense layers (computed during backpropagation) are not applied, so the model does not learn.
    • @hal9000 I compiled and fit this model on some random data. Since the loss decreases every epoch, the model does appear to be learning. Could you elaborate on what you tried? Perhaps there is some other problem in your code.
    • See my edit for a way to verify that the model is learning. Maybe you can check whether anything differs on your side.
    • Yes, you are right, the weights do get updated. Thanks for your help!
    【Solution 2】:

    You can multiply the (30, 600) matrix element-wise (with broadcasting) by a (600, 30, 600) weight tensor, giving a (600, 30, 600) result; summing over the last axis then yields the transpose of the output you want. I tested this in numpy rather than tensorflow, but it should work the same way.
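A minimal numpy check of this equivalence, using small stand-in sizes (4 rows of dimension 5 instead of 30 × 600):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5))      # stand-in for the (30, 600) matrix
w = rng.standard_normal((4, 5, 5))   # w[i] is the dense weight matrix for row i

# Reference: a different dense map per row, out[i] = x[i] @ w[i].
out = np.einsum('ij,ijk->ik', x, w)

# Broadcast-and-sum formulation: b[k, i, j] = w[i, j, k], so
# (x * b).sum(-1)[k, i] = sum_j x[i, j] * w[i, j, k] = out[i, k].
b = np.transpose(w, (2, 0, 1))       # shape (5, 4, 5)
alt = (x * b).sum(axis=-1)           # shape (5, 4): the transpose of out
assert np.allclose(alt.T, out)
```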

    【Comments】:
