【问题标题】:fft algorithm yields imprecise resultsfft 算法产生不精确的结果
【发布时间】:2017-08-02 13:59:29
【问题描述】:

我正在尝试基于 dft(离散傅立叶变换)矩阵分解来实现 fft(快速傅立叶变换)。在以下代码中,fft 和直接方法(即:将 dft 矩阵直接与 v 相乘)都是为了测试我的fft实现的有效性而实现的。

import numpy as n
import cmath, math
import matplotlib.pyplot as plt

v=n.array([1,-1,2,-3])
w=v
N=len(v)
t=[0]*N
M=n.zeros((N,N),dtype=complex)
z=n.exp(2j*math.pi/N)
for a in range(N):
    for b in range(N):
        M[a][b]=n.exp(2j*math.pi*a*b/N)
print (n.dot(v,M))
plt.plot(n.dot(v,M))
def f(x):
    x=n.concatenate([x[::2],x[1::2]])
    return x

while (w!=f(v)).any():
    v=f(v)
print(v)
a=2
while a<=N:

    for k in range(N/a):
        for y in range(a/2):
            t[y]=v[a*k+y]
        for i in range(a/2):
            v[a*k+i]+=v[a*k+i+a/2]*(z**i)
            v[a*k+i+a/2]=t[i]-v[a*k+i+a/2]*(z**i)
    a*=2    
print(v)
plt.plot(v)

plt.show()

我已经尝试了很多 v 的值,有时这两种方法的输出产生完全相同的结果,但有时它们彼此接近但不完全相同。经过几次不同的 v 值测试后,它们还没有远离彼此。

有什么我遗漏的东西会导致代码不精确吗?

编辑: 请注意,代码是为 Python 2 设计的(因为隐式整数除法)。

【问题讨论】:

  • 你在 python 2 上运行吗?
  • 不要忽视警告。 ComplexWarning:将复数转换为实数会丢弃虚部v[a*k+i+a/2]=t[i]-v[a*k+i+a/2]*(z**i) 敲响警钟? (提示:您为M 设置了正确的dtype。也请尝试为v。)
  • 看来解决方案其实在v的声明中(感谢@kazemakase)。请改用v=n.array([1,-1,2,-3], dtype=complex)。至少对我来说,曲线比彼此重叠。
  • @ThomasKühn 嗨,Thomas,请在您的回答下查看我的评论。
  • @kazemakase 这适用于我在帖子中显示的内容,但不适用于 v=[1,2,3,4,5,6,7,8] 并且警告仍然显示在我为 v 设置了 dtype 之后。

标签: numpy matplotlib fft dft cmath


【解决方案1】:

看来问题不在于算法,而在于 v 的声明(感谢@kazemakase)。试试

v=n.array([1,-1,2,-3], dtype=complex) 

相反。至少对我来说,曲线会出现在彼此之上:

编辑

这真是一段旅程。我无法弄清楚您的代码有什么问题,但看起来 dft 和 fft 都有几个错误。最后,我根据 [this document] (http://www.cs.cmu.edu/afs/andrew/scs/cs/15-463/2001/pub/www/notes/fourier/fourier.pdf) 编写了自己的 fft 版本(第 6 - 9 页包含您需要的所有信息)。也许您可以通过算法找出问题所在。位反转的算法可以在this answer(或者this one)中找到。我测试了不同长度的线性向量的代码——如果你发现任何错误,请告诉我。

import numpy as np
import cmath

def bit_reverse(x,n):
    """
    Reverse the last n bits of x
    """

    ##from https://stackoverflow.com/a/12682003/2454357
    ##formstr = '{{:0{}b}}'.format(n)
    ##return int(formstr.format(x)[::-1],2)

    ##from https://stackoverflow.com/a/5333563/2454357
    return sum(1<<(n-1-i) for i in range(n) if x>>i&1)

def permute_vector(v):
    """
    Permute vector v such that the indices of the result
    correspond to the bit-reversed indices of the original.
    Returns the permuted input vector and the number of bits used.
    """
    ##check that len(v) == 2**n
    ##and at the same time find permutation length:
    L = len(v)
    comp = 1
    bits = 0
    while comp<L:
        comp *= 2
        bits += 1
    if comp != L:
        raise ValueError('permute_vector: wrong length of v -- must be 2**n')
    rindices = [bit_reverse(i,bits)for i in range(L)]
    return v[rindices],bits

def dft(v):
    N = v.shape[0]
    a,b = np.meshgrid(
        np.linspace(0,N-1,N,dtype=np.complex128),
        np.linspace(0,N-1,N,dtype=np.complex128),
    )
    M = np.exp((-2j*np.pi*a*b)/N)

    return np.dot(M,v)


def fft(v):
    w,bits = permute_vector(v)
    N = w.shape[0]
    z=np.exp(np.array(-2j,dtype=np.complex128)*np.pi/N)

    ##starting fft
    for i in range(bits): 
        dist = 2**i  ##distance between 'exchange pairs'
        group = dist*2 ##size of sub-groups
        for start in range(0,N,group):
            for offset in range(group//2):
                pos1 = start+offset
                pos2 = pos1+dist
                alpha1 = z**((pos1*N//group)%N)
                alpha2 = z**((pos2*N//group)%N)
                w[pos1],w[pos2] = w[pos1]+alpha1*w[pos2],w[pos1]+alpha2*w[pos2]
    return w

if __name__ == '__main__':

    #test the fft
    for n in [2**i for i in range(1,5)]:
        print('-'*25+'n={}'.format(n)+'-'*25)
        v = np.linspace(0,n-1,n, dtype=np.complex128)
        print('v = ')
        print(v)
        print('fft(v) = ')
        print(fft(v))
        print('dft(v) = ')
        print(dft(v))
        print('relative error:')
        print(abs(fft(v)-dft(v))/abs(dft(v)))

这给出了以下输出:

-------------------------n=2-------------------------
v = 
[ 0.+0.j  1.+0.j]
fft(v) = 
[ 1. +0.00000000e+00j -1. -1.22464680e-16j]
dft(v) = 
[ 1. +0.00000000e+00j -1. -1.22464680e-16j]
relative error:
[ 0.  0.]
-------------------------n=4-------------------------
v = 
[ 0.+0.j  1.+0.j  2.+0.j  3.+0.j]
fft(v) = 
[ 6. +0.00000000e+00j -2. +2.00000000e+00j -2. -4.89858720e-16j
 -2. -2.00000000e+00j]
dft(v) = 
[ 6. +0.00000000e+00j -2. +2.00000000e+00j -2. -7.34788079e-16j
 -2. -2.00000000e+00j]
relative error:
[  0.00000000e+00   0.00000000e+00   1.22464680e-16   3.51083347e-16]
-------------------------n=8-------------------------
v = 
[ 0.+0.j  1.+0.j  2.+0.j  3.+0.j  4.+0.j  5.+0.j  6.+0.j  7.+0.j]
fft(v) = 
[ 28. +0.00000000e+00j  -4. +9.65685425e+00j  -4. +4.00000000e+00j
  -4. +1.65685425e+00j  -4. -7.10542736e-15j  -4. -1.65685425e+00j
  -4. -4.00000000e+00j  -4. -9.65685425e+00j]
dft(v) = 
[ 28. +0.00000000e+00j  -4. +9.65685425e+00j  -4. +4.00000000e+00j
  -4. +1.65685425e+00j  -4. -3.42901104e-15j  -4. -1.65685425e+00j
  -4. -4.00000000e+00j  -4. -9.65685425e+00j]
relative error:
[  0.00000000e+00   6.79782332e-16   7.40611132e-16   1.85764404e-15
   9.19104080e-16   3.48892999e-15   3.92837008e-15   1.35490975e-15]
-------------------------n=16-------------------------
v = 
[  0.+0.j   1.+0.j   2.+0.j   3.+0.j   4.+0.j   5.+0.j   6.+0.j   7.+0.j
   8.+0.j   9.+0.j  10.+0.j  11.+0.j  12.+0.j  13.+0.j  14.+0.j  15.+0.j]
fft(v) = 
[ 120. +0.00000000e+00j   -8. +4.02187159e+01j   -8. +1.93137085e+01j
   -8. +1.19728461e+01j   -8. +8.00000000e+00j   -8. +5.34542910e+00j
   -8. +3.31370850e+00j   -8. +1.59129894e+00j   -8. +2.84217094e-14j
   -8. -1.59129894e+00j   -8. -3.31370850e+00j   -8. -5.34542910e+00j
   -8. -8.00000000e+00j   -8. -1.19728461e+01j   -8. -1.93137085e+01j
   -8. -4.02187159e+01j]
dft(v) = 
[ 120. +0.00000000e+00j   -8. +4.02187159e+01j   -8. +1.93137085e+01j
   -8. +1.19728461e+01j   -8. +8.00000000e+00j   -8. +5.34542910e+00j
   -8. +3.31370850e+00j   -8. +1.59129894e+00j   -8. -6.08810394e-14j
   -8. -1.59129894e+00j   -8. -3.31370850e+00j   -8. -5.34542910e+00j
   -8. -8.00000000e+00j   -8. -1.19728461e+01j   -8. -1.93137085e+01j
   -8. -4.02187159e+01j]
relative error:
[  0.00000000e+00   1.09588741e-15   1.45449990e-15   6.36716793e-15
   8.53211992e-15   9.06818284e-15   1.30922044e-14   5.40949529e-15
   1.11628436e-14   1.23698141e-14   1.50430426e-14   3.02428869e-14
   2.84810617e-14   1.16373983e-14   1.10680934e-14   3.92841628e-15]

这是一个很好的挑战——我学到了很多东西!您可以在线验证代码的结果,例如here

【讨论】:

  • 感谢您的建议。我按照你说的做了,它适用于你展示的那个,但它并不适用于 [1,2,3,4,5,6,7,8]
  • @pxc3110 你有你尝试实现的算法的链接吗?
  • 基本思想是离散傅立叶矩阵 exp(2jpia*b/N) 可以分解为单位矩阵、对角矩阵和奇数置换矩阵。
  • 非常感谢!我喜欢你学到了一些东西,这是这个网站的重点。以下链接提供了实际实现而不是理论部分:wiki.python.org/moin/NumericAndScientificRecipes 我想你可能也想看看这个?
猜你喜欢
  • 2020-06-18
  • 2021-03-01
  • 2019-08-02
  • 1970-01-01
  • 2020-12-14
  • 2015-10-12
  • 2015-01-06
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多