【问题标题】:Implement 2D convolution using FFT使用 FFT 实现 2D 卷积
【发布时间】:2021-03-29 10:25:34
【问题描述】:

TensorFlow.conv2d() 将大图像与大内核(过滤器)进行卷积非常慢。将 1024x1024 图像与相同大小的内核进行卷积需要几分钟。为了比较,cv2.filter2D() 立即返回结果。

我找到了tf.fft2()tf.rfft()

但是我不清楚如何使用这些函数执行简单的图像过滤。

如何使用 TensorFlow 使用 FFT 实现快速的 2D 图像过滤?

【问题讨论】:

  • 我不知道 Tensorflow,但我认为你可以使用 convolution theorem 来做到这一点。
  • tf.nn.conv2d 优化得很好;您是否将图形构建时间包括在“几分钟”中?在这种情况下,您可以尝试tf.enable_eager_execution()。如果没有 Minimal, Complete, and Verifiable 的方式来重现您所询问的性能问题,很难说更多。
  • fft2 在张量流中不起作用。

标签: python image-processing tensorflow fft


【解决方案1】:

可以使用卷积定理和离散时间傅里叶变换 (DTFT) 计算 x * y 形式的线性离散卷积。如果x * y 是一个圆形离散卷积,则可以使用离散傅里叶变换 (DFT) 进行计算。

卷积定理状态x * y 可以使用傅里叶变换计算为

其中 表示傅里叶变换, 表示傅里叶逆变换。当xy 是离散的并且它们的卷积是线性卷积时,这是使用 DTFT 计算的

如果 xy 是离散的,并且它们的卷积是循环卷积,则上面的 DTFT 将被 DFT 替换。 注意:线性卷积问题可以嵌入到循环卷积问题中。


我更熟悉 MATLAB,但通过阅读 tf.signal.fft2dtf.signal.ifft2d 的 TensorFlow 文档,通过替换 MATLAB 函数 fft2ifft2,下面的解决方案应该可以轻松转换为 TensorFlow。

在 MATLAB(和 TensorFlow)中,fft2(和 tf.signal.fft2d)使用快速傅里叶变换算法计算 DFT。如果xy 的卷积是循环的,则可以通过

ifft2(fft2(x).*fft2(y))

其中.* 表示 MATLAB 中的逐元素乘法。但是,如果它是线性的,那么我们将数据零填充到长度2N-1,其中N 是一维的长度(问题中的 1024)。在 MATLAB 中,这可以通过以下两种方式之一进行计算。首先,由

h = ifft2(fft2(x, 2*N-1, 2*N-1).*fft2(y, 2*N-1, 2*N-1));

其中 MATLAB 通过零填充计算 2*N-1-point 二维傅里叶变换 xy,然后计算 2*N-1-point 二维傅里叶逆变换。此方法不能在 TensorFlow 中使用(根据我对文档的理解),因此下一个是唯一的选择。在 MATLAB 和 TensorFlow 中,可以通过首先将 xy 扩展到 2*N-1 x 2*N-1 的大小,然后计算 2*N-1 点 2D 傅里叶变换和傅里叶逆变换来计算卷积

x_extended = x;
x_extended(2*N-1, 2*N-1) = 0;

y_extended = y;
y_extended(2*N-1, 2*N-1) = 0;

h_extended = ifft2(fft2(x_extended).*fft2(y_extended));

在 MATLAB 中,hh_extended 完全相等。 xy 的卷积可以在没有傅里叶变换的情况下计算出

hC = conv2(x, y);

在 MATLAB 中。


在我笔记本电脑上的 MATLAB 中 conv2(x, y) 需要 55 秒,而傅里叶变换方法需要不到 0.4 秒。

【讨论】:

    【解决方案2】:

    这可以通过类似于实现 scipy.signal.fftconvolve 的方式来完成。

    这是一个示例,假设我们有一个图像(2 维,如果您还有多个通道,则可以使用 3d 而不是 2 个函数)(im)和一个过滤器(例如高斯)。

    首先,对图像进行傅里叶变换并定义fft_lenghts(如果过滤器具有不同的形状,则很有用,在这种情况下,它将被零填充。)

    fft_lenght1 = tf.shape(im)[0]
    fft_lenght2 = tf.shape(im)[1]
    im_fft = tf.signal.rfft2d(im, fft_length=[fft_lenght1, fft_lenght2])
    

    接下来,对滤波器进行 FFT(注意,例如,对于 2d 高斯滤波器,请确保中心位于左上角,即仅使用“四分之一”)

    kernel_fft = tf.signal.rfft2d(kernel, fft_length=[fft_lenght1, fft_lenght2])
    

    最后再进行逆变换得到卷积后的图像

    im_blurred = tf.signal.irfft2d(im_fft * kernel_fft, [fft_lenght1, fft_lenght2])
    

    【讨论】:

      猜你喜欢
      • 2014-05-18
      • 1970-01-01
      • 1970-01-01
      • 2011-03-06
      • 2012-12-10
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2011-12-30
      相关资源
      最近更新 更多