【问题标题】:Tensorflow: reshape [N,H,W,C] to [N*C,H,W,1] for convolution per channelTensorflow:将 [N,H,W,C] 重塑为 [N*C,H,W,1] 以进行每个通道的卷积
【发布时间】:2026-02-22 03:15:01
【问题描述】:

我想要实现的是使用一个过滤器应用 2D 卷积,该过滤器应用于所有通道。请注意,我不是在寻找深度卷积,而是真正的一个过滤器。为了做到这一点,我的计划是将[N,H,W,C]重塑为[N*C,H,W,1],应用卷积,然后重新整形,所以我的输出又是[N,H,W,C]

_, self.h, self.w, self.c = inputs.shape
self.conv = tf.keras.layers.Conv2D(filters=2, kernel_size=3, strides=1, padding='same')
x = tf.reshape(inputs, [-1,self.h,self.w,1])
x = self.conv(x)
x = tf.math.argmax(x, axis=3)
output = tf.reshape(x ,[-1,self.h,self.w,self.c])

然而,在实现这一点时,我注意到第一个 reshape 的输出包含通道或批次或其他东西之间的某种交错(图像来自 ImageNet):Before reshapeAfter reshape。我的直觉是,这可能是因为批处理和通道在内存中并不相邻。

出于这个原因,我通过先转置输入进行实验,然后再应用整形、卷积、整形和转置:

_, self.h, self.w, self.c = inputs.shape
self.conv = tf.keras.layers.Conv2D(filters=2, kernel_size=3, strides=1, padding='same', data_format="channels_first")
x_t = tf.transpose(inputs, [0,3,1,2]) # convert nhwc to nchw
x_t = tf.reshape(x_t, [-1,1,self.h,self.w])
x_t = self.conv(x_t)
x_t = tf.math.argmax(x_t, axis=1)
x_t = tf.reshape(x_t ,[-1,self.c,self.h,self.w])
output = tf.transpose(x_t, [0,2,3,1])

这似乎确实像我预期的那样工作,但这是一种相当缓慢的方法。 我有一系列问题:

  1. 我遇到的交错模式的确切原因是什么?
  2. 是否有一种无需使用转置即可重塑数据的方法?我知道我也许可以在任何地方使用NCHW 数据格式,但是当我尝试在现有平台上构建实现时,我认为更改数据格式会破坏代码的其他部分。
  3. 是否可以采用完全不同的方法对每个通道应用卷积?我曾考虑过使用 unstack 或其他东西,但这需要 for 循环,这在我的想法中效率更低。

提前致谢

edit:我想我至少明白为什么会发生交错。让我试着用我的理解来解释它,并尽可能地格式化它。这里的字母 n,h,w,c 应该有助于识别一个数字属于什么。

假设我有 16 个在内存中连续的数字:

1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16

如果它们的形状为NHWC: 2,2,2,2

[n[h[w[c1,c2],
     w[c3,c4]], 
   h[w[c5,c6],
     w[c7,c8]]], 
 n[h[w[c9,c10],
     w[c11,c12]], 
   h[w[c13,c14],
     w[c15,c16]]]]

然后,如果将它们重新整形为 NHWC: 4,2,2,1 并通过保持基础数据连续,我们得到:

[n[h[w[c1],
     w[c2]], 
   h[w[c3],
     w[c4]]], 
 n[h[w[c5],
     w[c6]], 
   h[w[c7],
     w[c8]]], 
 n[h[w[c9],
     w[c10]], 
   h[w[c11],
     w[c12]]], 
 n[h[w[c13],
     w[c14]], 
   h[w[c15],
     w[c16]]]]

通过这种方式,通道在图像的空间维度中混合在一起。

【问题讨论】:

    标签: tensorflow reshape convolution


    【解决方案1】:

    首先,如果您想合并非连续维度,reshape 将无法正常工作,因此转置是个好主意。然而 Conv2d 期望最后一个维度是输入通道。在您的情况下,它是self.w,但它可能是您想要的。您可以像以前一样将重塑线更改为:tf.reshape(x_t, [-1, self.h, self.w, 1])

    【讨论】:

    • 确实,我似乎了解重塑方面正在发生的事情。我试着在我的帖子的编辑中写出来。第二个版本(带有转置)实际上是与dataformat="channels_first" 进行卷积以处理第二维中的通道。
    • 哦,我没明白