Flatten 层的作用
经过卷积运算,tf.keras.layers.Flatten 会将张量重塑为(n_samples, height*width*channels),例如将(16, 28, 28, 3) 转换为(16, 2352)。让我们试试吧:
import tensorflow as tf
x = tf.random.uniform(shape=(100, 28, 28, 3), minval=0, maxval=256, dtype=tf.int32)
flat = tf.keras.layers.Flatten()
flat(x).shape
TensorShape([100, 2352])
GlobalAveragePooling 层的作用
在卷积操作之后,tf.keras.layers.GlobalAveragePooling 层确实是根据最后一个轴对所有值进行平均。这意味着生成的形状将是(n_samples, last_axis)。例如,如果你的最后一个卷积层有 64 个过滤器,它将把 (16, 7, 7, 64) 变成 (16, 64)。让我们在一些卷积操作之后进行测试:
import tensorflow as tf
x = tf.cast(
tf.random.uniform(shape=(16, 28, 28, 3), minval=0, maxval=256, dtype=tf.int32),
tf.float32)
gap = tf.keras.layers.GlobalAveragePooling2D()
for i in range(5):
conv = tf.keras.layers.Conv2D(64, 3)
x = conv(x)
print(x.shape)
print(gap(x).shape)
(16, 24, 24, 64)
(16, 22, 22, 64)
(16, 20, 20, 64)
(16, 18, 18, 64)
(16, 16, 16, 64)
(16, 64)
你应该使用哪个?
Flatten 层将始终具有至少与GlobalAveragePooling2D 层一样多的参数。如果展平前的最终张量形状仍然很大,例如(16, 240, 240, 128),则使用Flatten 将产生大量参数:240*240*128 = 7,372,800。这个巨大的数字将乘以下一个密集层中的单元数!那时,GlobalAveragePooling2D 在大多数情况下可能是首选。如果你使用MaxPooling2D 和Conv2D 如此之多以至于你在展平之前的张量形状就像(16, 1, 1, 128),它不会有什么不同。如果你过度拟合,你可能想试试GlobalAveragePooling2D。