序
这里不会告诉你复杂的公式,只会告诉你怎么由将多个点的傅里叶变换转化为少数点的傅里叶变换,并且不按照预定好的反二进制序输入,仅仅需要按照自然数顺序输入即可。同时我会以实部和虚部的方式实现,而不是直接使用复数。这在C编程时更有优势。
图1 8点DFT分解成两点DFT的实现图
我所根据的便是以上这张图,我们可以看到输入其实是按照预定好的反向二进制序输入的:[0, 4, 2, 6]和[1, 5, 3 7]。而我们的输入一般是按照[0, 1, 2, 3, 4, 5, 6, 7]来输入的。
我们的输入可以进行变换:
[0, 1, 2, 3, 4, 5, 6, 7]->reshape(2, 4)
[[0, 1, 2, 3], [4, 5, 6, 7]] ->transpose(-1,-2)
[[0, 4], [1, 5], [2, 6], [3, 7]]
// 注相比反向二进制序,这种方式在NPU上更利于实现,如果NPU的矩阵乘加单元一次可以完成[16, 16] * [16, 16]那么我们只需要将输入做这样的变换即可,例如2048点[2048]->reshape->[16, 128]->transpose->[128, 16]然后我们就可以乘系数矩阵了。
基点的傅里叶变换
class IDFT(nn.Module):
def __init__(self, N = 2, back=2):
super(IDFT, self).__init__()
self.N = N
self.back = back
self.n1 = torch.arange(0, self.N)
self.k1 = torch.arange(0, self.N).view(-1, 1)
kn1 = self.k1 * self.n1
self.w1_real = torch.cos(kn1 * 2 * np.pi / self.N) / self.back
self.w1_imag = torch.sin(kn1 * 2 * np.pi / self.N) / self.back
def forward(self, x_real, x_imag):
y_real = torch.matmul(x_real, self.w1_real) - torch.matmul(x_imag, self.w1_imag)
y_imag = torch.matmul(x_real, self.w1_imag) + torch.matmul(x_imag, self.w1_real)
return y_real, y_imag
首先生成系数:
kn1 = self.k1 * self.n1
这个会生成一个N*N的矩阵,形式如下(例子为16点):
[ 0, 0, 0, ..., 0, 0]
[ 0, 1, 2, ..., 14, 15]
[ 0, 2, 4, ..., 28, 30]
[..., ..., ..., ..., ..., ...]
[..., ..., ..., ..., ..., ...]
[ 0, 14, 28, ..., 196, 210]
[ 0, 15, 30, ..., 210, 225]
然后根据此系数生成FFT或IFFT系数的实部和虚部(代码为IFFT),forward
函数为完成基点的傅里叶变换。
基于基点的傅里叶变换的分解
在序中可以看到我们的输入是按照正向的顺序,即[[0, 4], [1, 5], [2, 6], [3, 7]], 那么在合并时,序号[2, 6], [3, 7]需要乘[W_n^0,W_N^2],即对做完基点的输出矩阵后半段乘对应的系数,合并时按照最后一个轴进行合并。
class IFFT(nn.Module):
def __init__(self, N = 2, base = 2):
super(IFFT, self).__init__()
self.N = N
self.base = base
self.n = torch.arange(0, self.N)[ : self.N // 2]
self.w2_real = torch.cos(self.n * 2 * np.pi / self.N)
self.w2_imag = torch.sin(self.n * 2 * np.pi / self.N)
self.IDFT = IDFT(N=base, back=self.N)
def forward(self, x_real, x_imag):
x_real = x_real.reshape(-1, self.base, self.N // self.base).transpose(-1, -2)
x_imag = x_imag.reshape(-1, self.base, self.N // self.base).transpose(-1, -2)
x_real, x_imag = self.IDFT(x_real, x_imag)
row = self.N // self.base
while row != 1:
x_real_temp = x_real[ : , row // 2 :, : ] * self.w2_real[::row // 2] \
- x_imag[ : , row // 2 :, : ] * self.w2_imag[::row // 2]
x_imag_temp = x_real[ : , row // 2 :, : ] * self.w2_imag[::row // 2] \
+ x_imag[ : , row // 2 :, : ] * self.w2_real[::row // 2]
x_merge_1_real = x_real[ : , : row // 2, : ] + x_real_temp
x_merge_1_imag = x_imag[ : , : row // 2, : ] + x_imag_temp
x_merge_2_real = x_real[ : , : row // 2, : ] - x_real_temp
x_merge_2_imag = x_imag[ : , : row // 2, : ] - x_imag_temp
x_real = torch.cat([x_merge_1_real, x_merge_2_real], dim=-1)
x_imag = torch.cat([x_merge_1_imag, x_merge_2_imag], dim=-1)
row = row >> 1
return x_real, x_imag
评论区