由于我的问题没有得到任何答案，我想分享一下我想出的解决方案，以防有人通过搜索找到这个问题。
如果您找到或提出更好的解决方案,请告诉我 - 谢谢!
class SimpleModel(tf.keras.Model):
    """Feed-forward Keras model with one stateful recurrent connection.

    The activation of ``node_3`` is stored in a non-trainable variable and, on
    the next forward pass, concatenated (last axis) with the input before it
    enters ``node_2``. The stored state is created with a fixed shape derived
    from ``input_shape``, so the tensors passed to the model must agree with
    that shape on every axis except the last for the concatenation to work.

    Args:
        input_shape: Shape used both to configure the input layer and to size
            the recurrent state. In the accompanying demo it equals the full
            shape of the tensor fed to the model, e.g. ``(10, 128)``.
            NOTE(review): Keras' ``InputLayer(input_shape=...)`` normally
            excludes the batch axis -- confirm this dual use is intended.
        *args: Forwarded to ``tf.keras.Model.__init__``.
        **kwargs: Forwarded to ``tf.keras.Model.__init__``.
    """

    def __init__(self, input_shape, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Node layers.
        self.node_1 = tf.keras.layers.InputLayer(input_shape=input_shape)
        self.node_2 = tf.keras.layers.Dense(64, activation='sigmoid')
        self.node_3 = tf.keras.layers.Dense(32, activation='sigmoid')
        self.node_4 = tf.keras.layers.Dense(16, activation='sigmoid')

        # Joins the input with the stored node_3 activation. Created once here
        # instead of instantiating a new layer on every forward pass.
        self.conn_3_2_concat = tf.keras.layers.Concatenate(axis=-1)

        # Shape that node_3 will receive, needed to size the recurrent state.
        node_1_output_shape = self.node_1.compute_output_shape(input_shape)
        node_2_output_shape = self.node_2.compute_output_shape(node_1_output_shape)

        # Recurrent connection state (node_3 -> node_2). The initial value is
        # node_3 applied to an all-ones tensor; any other tensor of node_3's
        # output shape (zeros, ones, random uniform, ...) could be passed as
        # ``initial_value`` instead. It must remain a ``tf.Variable`` because
        # ``call`` updates it in place with ``assign``.
        self.conn_3_2_recurrent_state = tf.Variable(
            initial_value=self.node_3(tf.ones(shape=node_2_output_shape)),
            trainable=False,
            validate_shape=False,
            dtype=tf.float32,
        )

    def call(self, inputs):
        """Run one forward pass and update the recurrent state.

        Args:
            inputs: Input tensor; it is concatenated along the last axis with
                the stored state, so the remaining axes must match the state.

        Returns:
            The output of ``node_4``.

        Side effects:
            Overwrites ``conn_3_2_recurrent_state`` with the current output of
            ``node_3``.
        """
        x = self.node_1(inputs)
        # tf.print(self.conn_3_2_recurrent_state)
        # tf.print(self.conn_3_2_recurrent_state.shape)
        x = self.conn_3_2_concat([x, self.conn_3_2_recurrent_state])
        x = self.node_2(x)
        x = self.node_3(x)
        # Remember node_3's activation for the next forward pass.
        self.conn_3_2_recurrent_state.assign(x)
        # tf.print(self.conn_3_2_recurrent_state)
        # tf.print(self.conn_3_2_recurrent_state.shape)
        return self.node_4(x)
# Statefulness demo: push the same input through the model twice. Uncomment the
# tf.print calls in SimpleModel.call() to watch the stored state change.
model = SimpleModel(input_shape=(10, 128))
x = tf.ones((10, 128))
model(x)
model(x)

# Trainability demo: fit a fresh model on random inputs against all-ones targets.
y = tf.ones((10, 16))
x = tf.random.uniform((10, 128))
model = SimpleModel(input_shape=(10, 128))
model.compile(loss='binary_crossentropy', optimizer='adam')
model.fit(x, y, epochs=100)