【问题标题】:How to create a recurrent connection between 2 layers in Tensorflow/Keras?如何在 Tensorflow/Keras 中的 2 层之间创建循环连接?
【发布时间】:2021-01-02 01:34:52
【问题描述】:

基本上我想做的是采用以下非常简单的前馈图:

然后添加一个循环层,将第二个密集层的输出作为输入提供给第一个密集层,如下所示。这两种模型显然都是对我的实际用例的简化,尽管我认为我所要求的一般原则对两者都适用。

我想知道 Tensorflow 甚至 keras 中是否有一种有效的方法来实现这一点,尤其是在 GPU 处理效率方面。虽然我相当有信心可以在 Tensorflow 中组合一个自定义模型来完成这个功能,但我是否对这种自定义模型的 GPU 处理效率感到悲观。因此,如果有人知道一种有效的方法来完成这些两层之间的循环连接,我将非常感激。感谢您的时间! =)


为了完整起见,这里是创建第一个简单前馈图的代码。我通过图像编辑创建的循环图。

inputs = tf.keras.Input(shape=(128,))

h_1 = tf.keras.layers.Dense(64)(inputs)
h_2 = tf.keras.layers.Dense(32)(h_1)
out = tf.keras.layers.Dense(16)(h_2)

model = tf.keras.Model(inputs, out)

【问题讨论】:

    标签: tensorflow keras recurrent-neural-network


    【解决方案1】:

    由于我的问题没有得到任何答案,我想分享一下我想出的解决方案,以防有人通过搜索找到这个问题。

    如果您找到或提出更好的解决方案,请告诉我 - 谢谢!

    class SimpleModel(tf.keras.Model):
        def __init__(self, input_shape, *args, **kwargs):
            super(SimpleModel, self).__init__(*args, **kwargs)
            # Create node layers
            self.node_1 = tf.keras.layers.InputLayer(input_shape=input_shape)
            self.node_2 = tf.keras.layers.Dense(64, activation='sigmoid')
            self.node_3 = tf.keras.layers.Dense(32, activation='sigmoid')
            self.node_4 = tf.keras.layers.Dense(16, activation='sigmoid')
            self.conn_3_2_recurrent_state = None
    
            # Create recurrent connection states
            node_1_output_shape = self.node_1.compute_output_shape(input_shape)
            node_2_output_shape = self.node_2.compute_output_shape(node_1_output_shape)
            node_3_output_shape = self.node_3.compute_output_shape(node_2_output_shape)
    
            self.conn_3_2_recurrent_state = tf.Variable(initial_value=self.node_3(tf.ones(shape=node_2_output_shape)),
                                                        trainable=False,
                                                        validate_shape=False,
                                                        dtype=tf.float32)
            # OR
            # self.conn_3_2_recurrent_state = tf.random.uniform(shape=node_3_output_shape, minval=0.123, maxval=4.56)
            # OR
            # self.conn_3_2_recurrent_state = tf.ones(shape=node_3_output_shape)
            # OR
            # self.conn_3_2_recurrent_state = tf.zeros(shape=node_3_output_shape)
    
        def call(self, inputs):
            x = self.node_1(inputs)
    
            #tf.print(self.conn_3_2_recurrent_state)
            #tf.print(self.conn_3_2_recurrent_state.shape)
    
            x = tf.keras.layers.Concatenate(axis=-1)([x, self.conn_3_2_recurrent_state])
            x = self.node_2(x)
            x = self.node_3(x)
    
            self.conn_3_2_recurrent_state.assign(x)
            #tf.print(self.conn_3_2_recurrent_state)
            #tf.print(self.conn_3_2_recurrent_state.shape)
    
            x = self.node_4(x)
            return x
    
    
    # Demonstrate statefulness of model (uncomment tf prints in model.call())
    model = SimpleModel(input_shape=(10, 128))
    x = tf.ones(shape=(10, 128))
    model(x)
    model(x)
    
    
    # Demonstrate trainability of the recurrent connection TF model
    x = tf.random.uniform(shape=(10, 128))
    y = tf.ones(shape=(10, 16))
    
    model = SimpleModel(input_shape=(10, 128))
    model.compile(optimizer='adam', loss='binary_crossentropy')
    model.fit(x=x, y=y, epochs=100)
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2018-04-26
      • 2017-01-06
      • 1970-01-01
      • 2019-07-29
      • 1970-01-01
      • 1970-01-01
      • 2017-08-29
      相关资源
      最近更新 更多