【问题标题】:Implementing slice layer in keras在keras中实现切片层
【发布时间】:2020-02-11 11:47:21
【问题描述】:

(免责声明:我已将我的问题简化为重点,我想做的稍微复杂一些,但我在这里描述了核心问题。)

我正在尝试使用keras 构建一个网络来学习一些 5 x 5 矩阵的属性。

输入数据是一个 1000 x 5 x 5 numpy 数组的形式,其中每个 5 x 5 子数组代表一个矩阵。

我希望网络做的是使用矩阵中每一行的属性,所以我想将每个 5 x 5 数组拆分为单独的 1 x 5 数组,并将这 5 个数组中的每一个传递到下一个网络的一部分。

这是我目前所拥有的:

input_mat = keras.Input(shape=(5,5), name='Input')

part_list = list()   
for i in range(5):
    part_list.append(keras.layers.Lambda(lambda x: x[i,:])(input_mat)) 

dense_list = list()
for i in range(5):
    dense_list.append( keras.layers.Dense(10, activation='selu', 
                                          use_bias=True)(part_list[i]) )


conc = keras.layers.Concatenate(axis=-1, name='Concatenate')(dense_list)
dense_out = keras.layers.Dense(1, name='D_out', activation='sigmoid')(conc)


model = keras.Model(inputs= input_mat, outputs=dense_out)
model.compile(optimizer='adam', loss='mean_squared_error')

我的问题是这似乎训练得不好,查看模型摘要我不确定网络是否按照我的意愿拆分输入:

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input (InputLayer)              (None, 5, 5)         0                                            
__________________________________________________________________________________________________
lambda_5 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_6 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_7 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_8 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
lambda_9 (Lambda)               (5, 5)               0           Input[0][0]                      
__________________________________________________________________________________________________
dense (Dense)                   (5, 10)              60          lambda_5[0][0]                   
__________________________________________________________________________________________________
dense_1 (Dense)                 (5, 10)              60          lambda_6[0][0]                   
__________________________________________________________________________________________________
dense_2 (Dense)                 (5, 10)              60          lambda_7[0][0]                   
__________________________________________________________________________________________________
dense_3 (Dense)                 (5, 10)              60          lambda_8[0][0]                   
__________________________________________________________________________________________________
dense_4 (Dense)                 (5, 10)              60          lambda_9[0][0]                   
__________________________________________________________________________________________________
Concatenate (Concatenate)       (5, 50)              0           dense[0][0]                      
                                                                 dense_1[0][0]                    
                                                                 dense_2[0][0]                    
                                                                 dense_3[0][0]                    
                                                                 dense_4[0][0]                    
__________________________________________________________________________________________________
D_out (Dense)                   (5, 1)               51          Concatenate[0][0]                
==================================================================================================
Total params: 351
Trainable params: 351
Non-trainable params: 0

Lambda 层的输入和输出节点在我看来是错误的,但恐怕我仍然难以理解这个概念。

【问题讨论】:

标签: arrays tensorflow keras


【解决方案1】:

应避免使用 Lambda。

改为子类层:

class Slice(keras.layers.Layer):
    def __init__(self, begin, size,**kwargs):
        super(Slice, self).__init__(**kwargs)
        self.begin = begin
        self.size = size
    def get_config(self):

        config = super().get_config().copy()
        config.update({
            'begin': self.begin,
            'size': self.size,
        })
        return config
    def call(self, inputs):
        return tf.slice(inputs, self.begin, self.size)

【讨论】:

    【解决方案2】:

    在一行

    part_list.append(keras.layers.Lambda(lambda x: x[i,:])(input_mat)) 
    

    您基本上是在拍摄 1000 张图像中的前 5 张,这不是您想要做的。

    要实现你想要的,试试tensorflow的unstack操作:

    part_list = tf.unstack(input_mat, axis=1)
    

    这应该为您提供一个包含 5 个元素的列表,每个元素的形状为 [1000, 5]

    【讨论】:

    • 谢谢,但我最终得到了错误Output tensors to a Model must be the output of a TensorFlow Layer`(因此保留了过去的层元数据)。发现:Tensor("D_out_2/Sigmoid:0", shape=(?, 1), dtype=float32)` - 使用较低级别的tensorflow 模块似乎是个问题
    • 您可以将 op 包裹在 Lambda 层中或为其创建自定义层。
    猜你喜欢
    • 2019-02-17
    • 2017-08-19
    • 2019-09-03
    • 2011-02-25
    • 2021-02-21
    • 1970-01-01
    相关资源
    最近更新 更多