【发布时间】:2020-09-12 00:34:09
【问题描述】:
我的模型中有一个图像作为输入,但我需要输入一些浮点数以及有关图像的支持信息,但我不希望它通过所有卷积,我希望它直接进入我的密集层作为如何训练它的信息。我知道连接层,但我不知道如何在输入中使用它,或者是否应该这样做。
【问题讨论】:
标签: python keras input concatenation
我的模型中有一个图像作为输入,但我需要输入一些浮点数以及有关图像的支持信息,但我不希望它通过所有卷积,我希望它直接进入我的密集层作为如何训练它的信息。我知道连接层,但我不知道如何在输入中使用它,或者是否应该这样做。
【问题讨论】:
标签: python keras input concatenation
假设您有一个backbone,它可以是任何卷积神经网络(VGG、ResNet 等)。在密集层之前,您通常有一个Flatten()(或者,在现代神经网络中,您通常有一个池化层,如GAP 或GeM),它准备一个一维向量作为密集层的输入。这就是你可以与你的浮动连接的地方。
使用函数式 API 的代码示例:
class MyModel(tf.keras.Model):
def __init__(self, num_output_classes):
super().__init__()
self.backbone = tf.keras.applications.ResNet50(
input_shape=(224, 224, 3), include_top=False)
self.pool = tf.keras.layers.GlobalAveragePooling2D()
self.concat = tf.keras.layers.Concatenate(axis=-1)
self.dense = tf.keras.layers.Dense(num_output_classes)
def call(self, inputs):
# Unpack the inputs. `additional_floats` should be 1D
image, additional_floats = inputs
# Run image through backbone and get a feature vector
x = self.backbone(image)
x = self.pool(x)
# Concatenate with your additional floats
x = self.concat([x, additional_inputs])
# Classification, or whatever you might need on top
return self.dense(x, activation='softmax')
【讨论】:
build 构建模型如果您的图层不支持浮点类型输入。相反,为了实例化和构建您的模型,call 您的模型基于真实张量数据(具有正确的 dtype)。