【发布时间】:2020-12-31 23:05:06
【问题描述】:
我正在尝试实现一个卷积自动编码器,其中一些卷积过滤器依赖于输入内容。例如,在一个简单的玩具示例中,了解 MNIST 的数字标签可以进一步帮助在自动编码器设置中进行重建。
更一般的想法是,可能存在一些有用的相关辅助信息(无论是类标签还是其他一些信息)。虽然有多种方法可以使用此标签/辅助信息,但我将通过创建一个单独的卷积过滤器来实现。假设该模型有 15 个典型的卷积滤波器,我想添加一个对应于 MNIST 数字的额外卷积滤波器,可以将其视为 3x3 内核形式的数字嵌入。我们将使用该数字作为网络的附加输入,然后为每个数字学习不同的内核/过滤器嵌入。
但是,我在实现依赖于输入的卷积过滤器/内核时遇到了困难。我没有使用tf.keras.layers.Conv2D 层,因为它接受了要使用的过滤器的数量,而不是实际的过滤器参数来使这个输入依赖。
# load and preprocess data
num_classes = 10
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = np.float32(x_train)/255, np.float32(x_test)/255
x_train, x_test = np.expand_dims(x_train, axis=-1), np.expand_dims(x_test, axis=-1)
y_train = keras.utils.to_categorical(y_train, num_classes=num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes=num_classes)
num_filters = 15
input_img = layers.Input(shape=(28,28,1))
conv_0 = keras.layers.Conv2D(num_filters, (3,3), strides=2, padding='same', activation='relu')(input_img)
# embed the target as a 3x3 kernel/filter -> this should map to a distinct embedding for
# each target
target = layers.Input(shape=(10,))
target_encoded = layers.Dense(9, activation='relu')(target)
target_encoded = layers.Reshape((3,3,1,1))(target_encoded)
# Using tf.nn.conv2d so that I can specify kernel
# Kernel needs to be a 4D tensor of dimensions (filter_height, filter_width, input_channels, output_channels)
# which in this case is (3,3,1,1)
# However it is currently (None,3,3,1,1) because the first dimension is batch size so this doesn't work
target_conv = tf.nn.conv2d(input_img, target_encoded, strides=[1, 1, 1, 1], padding='SAME')
我目前正在使用tf.nn.conv2d,它将内核作为格式的输入(filter_height、filter_width、input_channels、output_channels)。但是,这不起作用,因为数据是分批输入的。因此,批次中的每个样本都有一个标签,因此有一个相应的内核,因此内核的形状 (None, 3, 3, 1, 1) 与预期格式不兼容。这在上面的代码块中进行了说明(不起作用)。什么是潜在的解决方法?有没有更简单的方法来实现这个依赖于输入的 conv2d 过滤器的概念?
【问题讨论】:
-
您将如何在网络上执行推理?听起来您需要输入包含真实数字才能使网络正常工作。理想结构的问题在于,将真实标签作为输入和输出,优化的 CNN 将学习身份函数
f(x)=x。也就是说,您的网络将学会仅考虑输入标签,将所有其他像素乘以 0,将输入标签乘以 1。这样,您的 CNN 根本不会学习。 -
@ibarrond 这将是在自动编码器或变分自动编码器的上下文中,而不是在我们试图预测标签的分类中。因此网络获取图像和一些辅助信息(可能是标签)并输出重建图像(或在 vaes 上下文中生成的图像)。目标是看看知道标签或其他一些信息是否有助于重建/生成。
标签: tensorflow keras neural-network conv-neural-network convolution