【问题标题】:Input_shape for build method in TensorFlow custom layer with multiple inputs具有多个输入的 TensorFlow 自定义层中构建方法的 Input_shape
【发布时间】:2021-06-13 07:02:52
【问题描述】:

我必须设计一个接受两个输入 X_1X_2 的神经网络。该层将它们转换为固定大小的向量(10D),然后按以下方式对它们求和

class my_lyr(tf.keras.layers.Layer):
    def __init__(self):
        pass
    def call(self, X_1, X_2):
        return X_1 @ self.w1 + X_2 @ self.w2  

但是,在初始化w1w2 之前,我需要知道X_1X_2 的输入形状。 我不确定如何在build 中声明w2

def build(self, input_shape):
    self.w1 = self.add_weight('w1', shape=[input_shape[-1],10])
    // self.w2 = ?????

我想知道在这种情况下通常如何构建方法。

【问题讨论】:

    标签: python tensorflow machine-learning deep-learning neural-network


    【解决方案1】:

    如果你有两个这样的层输入,那么你可以简单地初始化你的权重,如下所示

    import tensorflow as tf 
    from tensorflow import keras 
    
    class Linear(keras.layers.Layer):
        def __init__(self, units=32):
            super(Linear, self).__init__()
            self.units = units
    
        def build(self, input_shape):
            self.wa = self.add_weight(
                shape=(input_shape[0][-1], self.units),
                initializer="random_normal",
                trainable=True,
            )
    
            self.wb = self.add_weight(
                shape=(input_shape[1][-1], self.units),
                initializer="random_normal",
                trainable=True,
            )
    
        def call(self, inputs):
            return tf.matmul(inputs[0], self.wa) + tf.matmul(inputs[1], self.wb)
    

    传递输入

    x = tf.random.normal(shape=(2,2))
    linear_layer = Linear(32)
    linear_layer([x, x])
    
    <tf.Tensor: shape=(2, 32), dtype=float32, numpy=
    array([[-0.08829461, -0.01605312, -0.04368614, -0.08116315, -0.01521384,
             0.01132785,  0.10704445, -0.10873697, -0.0525714 ,  0.07684848,
             0.04586978,  0.01315852,  0.01369547,  0.07404792,  0.10313608,
            -0.10851607,  0.04091477, -0.01723676, -0.0326797 ,  0.03598418,
            -0.11335816, -0.10044714,  0.13555384,  0.01689356,  0.02631954,
             0.08226107, -0.08765724, -0.05981663,  0.00531629,  0.02930426,
             0.04155847,  0.05339598],
           [ 0.20617458, -0.05936547,  0.01735754, -0.06575315,  0.10090968,
            -0.07796012, -0.1956767 , -0.03406558,  0.18604615, -0.03547171,
             0.02784208,  0.0471364 , -0.10712875, -0.07869454, -0.19457275,
             0.13593757, -0.14659101,  0.0384632 ,  0.02344182, -0.03861775,
             0.08948556,  0.09225713, -0.17395493,  0.10021958, -0.09210777,
            -0.09865301,  0.2536609 , -0.02547608,  0.02885125, -0.01271547,
            -0.10340843, -0.0338558 ]], dtype=float32)>
    

    【讨论】:

    • 好的,所以基本上必须使用[x x]
    猜你喜欢
    • 2021-10-17
    • 2020-10-04
    • 2020-01-21
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-11-09
    相关资源
    最近更新 更多