【问题标题】:Link weights in Keras layersKeras 层中的链接权重
【发布时间】:2019-08-23 03:57:40
【问题描述】:

假设我将输入分成两个大小相等的部分 I1、I2,并且我希望在我的 keras 网络上具有以下结构——I1->A1,I2->A2,然后是 [A1,A2]->B,其中B是输出节点。我可以使用1 中的组来执行此操作。但是,我想要求 I1->A1 的连接权重(和其他激活参数)与 I2->A2 的连接权重相同,即我希望 1 和 2 之间具有对称性。 (请注意,我不需要 [A1,A2]->B 的对称性。)

【问题讨论】:

    标签: keras


    【解决方案1】:

    如果我正确理解您的问题,则 input_1 到 A_1 和 input_2 到 A_2 的映射已经一个接一个地完成,因为您希望两个输入的映射函数相同。在这种情况下,您可能会考虑使用 RNN,但如果您的输入彼此独立,您可能会考虑使用TimeDistributedwrapper in Keras。下面的示例将采用两个输入,并使用Dense 层将输入一一映射,因此Dense 的权重是共享的:

    from keras.models import Model
    from keras.layers import Input, Dense, TimeDistributed, Concatenate, Lambda
    
    x_dim = 5
    hidden_dim = 8
    
    x1 = Input(shape=(1,x_dim,))
    x2 = Input(shape=(1,x_dim,))
    
    concat = Concatenate(axis=1)([x1, x2])
    hidden_concat = TimeDistributed(Dense(hidden_dim))(concat)
    hidden1 = Lambda(lambda x: x[:,:1,:])(hidden_concat)
    hidden2 = Lambda(lambda x: x[:,1:,:])(hidden_concat)
    
    model = Model(inputs=[x1,x2], outputs=[hidden1, hidden2])
    model.summary()
    
    >>>
    Layer (type)                    Output Shape         Param #     Connected to                     
    ==================================================================================================
    input_33 (InputLayer)           (None, 1, 5)         0                                            
    __________________________________________________________________________________________________
    input_34 (InputLayer)           (None, 1, 5)         0                                            
    __________________________________________________________________________________________________
    concatenate_17 (Concatenate)    (None, 2, 5)         0           input_33[0][0]                   
                                                                     input_34[0][0]                   
    __________________________________________________________________________________________________
    time_distributed_9 (TimeDistrib (None, 2, 8)         48          concatenate_17[0][0]             
    __________________________________________________________________________________________________
    lambda_8 (Lambda)               (None, 1, 8)         0           time_distributed_9[0][0]         
    __________________________________________________________________________________________________
    lambda_9 (Lambda)               (None, 1, 8)         0           time_distributed_9[0][0]         
    ==================================================================================================
    Total params: 48
    Trainable params: 48
    Non-trainable params: 0
    

    【讨论】:

    • 谢谢!这正是我一直在寻找的。 (RNN 不适合我的问题,)
    猜你喜欢
    • 2017-03-22
    • 1970-01-01
    • 2018-09-20
    • 1970-01-01
    • 2017-10-07
    • 2019-04-18
    • 2017-09-28
    • 2017-12-01
    • 2021-02-13
    相关资源
    最近更新 更多