侧边栏壁纸
博主头像
Enjoyably博主等级

独行快,众行远,Walk Together!

  • 累计撰写 5 篇文章
  • 累计创建 3 个标签
  • 累计收到 0 条评论

目 录CONTENT

文章目录

快速傅里叶变换的手动实现

Relight
2025-06-02 / 0 评论 / 0 点赞 / 17 阅读 / 4984 字

这里不会告诉你复杂的公式,只会告诉你怎么由将多个点的傅里叶变换转化为少数点的傅里叶变换,并且不按照预定好的反二进制序输入,仅仅需要按照自然数顺序输入即可。同时我会以实部和虚部的方式实现,而不是直接使用复数。这在C编程时更有优势。

图1 8点DFT分解成两点DFT的实现图

我所根据的便是以上这张图,我们可以看到输入其实是按照预定好的反向二进制序输入的:[0, 4, 2, 6]和[1, 5, 3 7]。而我们的输入一般是按照[0, 1, 2, 3, 4, 5, 6, 7]来输入的。

我们的输入可以进行变换:

[0, 1, 2, 3, 4, 5, 6, 7]->reshape(2, 4)
[[0, 1, 2, 3], [4, 5, 6, 7]] ->transpose(-1,-2)
[[0, 4], [1, 5], [2, 6], [3, 7]]
// 注相比反向二进制序,这种方式在NPU上更利于实现,如果NPU的矩阵乘加单元一次可以完成[16, 16] * [16, 16]那么我们只需要将输入做这样的变换即可,例如2048点[2048]->reshape->[16, 128]->transpose->[128, 16]然后我们就可以乘系数矩阵了。

基点的傅里叶变换

class IDFT(nn.Module):
    def __init__(self, N = 2, back=2):
        super(IDFT, self).__init__()
        self.N = N
        self.back = back
        self.n1 = torch.arange(0, self.N)
        self.k1 = torch.arange(0, self.N).view(-1, 1)
        kn1 = self.k1 * self.n1
        self.w1_real = torch.cos(kn1 * 2 * np.pi / self.N) / self.back
        self.w1_imag = torch.sin(kn1 * 2 * np.pi / self.N) / self.back
    def forward(self, x_real, x_imag):
        y_real = torch.matmul(x_real, self.w1_real) - torch.matmul(x_imag, self.w1_imag)
        y_imag = torch.matmul(x_real, self.w1_imag) + torch.matmul(x_imag, self.w1_real)
        return y_real, y_imag

首先生成系数:

kn1 = self.k1 * self.n1 这个会生成一个N*N的矩阵,形式如下(例子为16点):

[  0,   0,   0, ...,   0,   0]
[  0,   1,   2, ...,  14,  15]
[  0,   2,   4, ...,  28,  30]
[..., ..., ..., ..., ..., ...]
[..., ..., ..., ..., ..., ...]
[  0,  14,  28, ..., 196, 210]
[  0,  15,  30, ..., 210, 225]

然后根据此系数生成FFT或IFFT系数的实部和虚部(代码为IFFT),forward 函数为完成基点的傅里叶变换。

基于基点的傅里叶变换的分解

在序中可以看到我们的输入是按照正向的顺序,即[[0, 4], [1, 5], [2, 6], [3, 7]], 那么在合并时,序号[2, 6], [3, 7]需要乘[W_n^0,W_N^2],即对做完基点的输出矩阵后半段乘对应的系数,合并时按照最后一个轴进行合并。

class IFFT(nn.Module):
    def __init__(self, N = 2, base = 2):
        super(IFFT, self).__init__()
        self.N = N
        self.base = base
        self.n = torch.arange(0, self.N)[ : self.N // 2]
        self.w2_real = torch.cos(self.n * 2 * np.pi / self.N)
        self.w2_imag = torch.sin(self.n * 2 * np.pi / self.N)
        self.IDFT = IDFT(N=base, back=self.N)

    def forward(self, x_real, x_imag):
        x_real = x_real.reshape(-1, self.base, self.N // self.base).transpose(-1, -2)
        x_imag = x_imag.reshape(-1, self.base, self.N // self.base).transpose(-1, -2)
        x_real, x_imag = self.IDFT(x_real, x_imag)
        row = self.N // self.base
        while row != 1:       
            x_real_temp = x_real[ : , row // 2 :, : ] * self.w2_real[::row // 2] \
                          - x_imag[ : , row // 2 :, : ] * self.w2_imag[::row // 2]
            
            x_imag_temp = x_real[ : , row // 2 :, : ] * self.w2_imag[::row // 2] \
                          + x_imag[ : , row // 2 :, : ] * self.w2_real[::row // 2]

            x_merge_1_real = x_real[ : , : row // 2, : ] + x_real_temp
            x_merge_1_imag = x_imag[ : , : row // 2, : ] + x_imag_temp
            
            x_merge_2_real = x_real[ : , : row // 2, : ] - x_real_temp
            x_merge_2_imag = x_imag[ : , : row // 2, : ] - x_imag_temp

            x_real = torch.cat([x_merge_1_real, x_merge_2_real], dim=-1)
            x_imag = torch.cat([x_merge_1_imag, x_merge_2_imag], dim=-1)
            row = row >> 1

        return x_real, x_imag

0

评论区