多项式笔记-快速傅里叶变换与快速数论变换

前言

快速傅里叶变换,用于在 O(nlogn) 的时间内解决两个n度多项式的乘法.

先介绍离散傅里叶变换,摘自百度百科:

离散傅里叶变换(DFT),是傅里叶变换在时域和频域上都呈现离散的形式,将时域信号的采样变换为在离散时间傅里叶变换(DTFT)频域的采样。在形式上,变换两端(时域和频域上)的序列是有限长的,而实际上这两组序列都应当被认为是离散周期信号的主值序列。即使对有限长的离散信号作DFT,也应当将其看作经过周期延拓成为周期信号再作变换。在实际应用中通常采用快速傅里叶变换以高效计算DFT。

多项式的表示

系数表示

用一个多项式各项的系数来表示一个多项式:

f(x)=a_0+a_1x+a_2x^2+a_3x^3+ \cdots+a_nx^n \leftrightarrow f(x)=\{a_0,a_1,a_2,a_3,\cdots,a_n\}

点值表示

将多项式视为函数,用在其上的n+1个点唯一表示一个函数的方法来表示一个多项式.

f(x)=a_0+a_1x+a_2x^2+a_3x^3+\cdots+a_nx^n \\ \leftrightarrow f(x)=\{(x_0,y_0),(x_1,y_1),(x_2,y_2,(x_3,y_3),\cdots,(x_n,y_n)\}

将多项式由系数表示转化为点值表示的过程即为DFT,将点值表示转化为系数表示的过程即IDFT(离散傅里叶逆变换),快速傅里叶变化即为通过取特殊的点带入点值表示来加速DFT和IDFT的过程.

单位复根

定义 x^n=1 在复数意义下的解为n的单位复根,这样的解有n个,设 \omega_n=e^\frac{2\pi i}{n} ,那么 x^n=1 的解集为 \{\omega_n^k|k=0,1,\cdots,n-1\} , \omega_n^i 在复数平面上表示为单位圆n等分后按极角排序后第i个角对应的向量,根据复数乘法模长相乘,辐角相加的定义,n的所有单位复根都可以用 \omega_n^1 的幂来表示.

然后根据欧拉公式,有 \omega_n=e^\frac{2\pi i}{n}=cos(\frac{2\pi i}{n})+i\cdot sin(\frac{2\pi i}{n}) .

对于任意正整数n和整数k,单位复根有如下性质:

\omega_n^k=\omega_{2n}^{2k} \\ \omega_n^n=1 \\ \omega_n^{k+n/2}=-\omega_{n}^k

这三条性质在之后的推导中尤为重要.

快速傅里叶变换

DFT

现在考虑将两个多项式进行DFT:

f(x)=\{(x_0,f(x_0)),(x_1,f(x_1)),(x_2,f(x_2)),\cdots,(x_n,f(x_n))\} \\ g(x)=\{(x_,g(x_0)),(x_1,g(x_1)),(x_2,g(x_2)),\cdots,(x_n,g(x_n))\}

F(x)=f(x)\cdot g(x) ,很容易就得到 F(x) 的点值表示:

F(x)=\{(x_0,f(x_0)\cdot g(x_0)),(x_1,f(x_1)\cdot g(x_1)),(x_2,f(x_2)\cdot g(x_2)),\cdots,(x_n,f(x_n)\cdot g(x_n))\}

但是要求的是 F(x) 的系数表示,我们需要在 O(n^2) 的时间内算出每一个 x^i ,然后消元这n+1个方程,复杂度显然过高.

所以不能直接地去计算 x^i 的值.考虑如果x是1和-1的话幂次就很容易得到了,但是我们需要n+1个点,所以就引入了刚刚提到的单位复根,这些数无论怎么乘长度都为1.

FFT基于分治,分为DFT和IDFT两个部分,先讨论DFT,它分治地求出 x=\omega_n^i 时的值.

分治基于把一个多项式分为奇数项和偶数项处理,以一个8项多项式为例:

f(x)=a_0+a_1x+a_2x^2+a_3x^3+a_4x^4+a_5x^5+a_6x^6+a_7x^7  \\ =(a_0+a_2x^2+a_4x^4+a_6x^6)+(a_1x+a_3x^3+a_5x^5+a_7x^7)  \\ =(a_0+a_2x^2+a_4x^4+a_6x^6)+x(a_1+a_3x^2+a_5x^4+a_7x^6)

按照项的奇偶建立新的多项式:

H(x)=(a_0+a_2x^2+a_4x^4+a_6x^6) \\ G(x)=x(a_1+a_3x^2+a_5x^4+a_7x^6)

原来的 f(x) 就可以表示为:

F(x)=H(x^2)+xG(x^2)

带入 \omega_n^k 进行DFT,得到:

DFT(F(\omega_n^k))=DFT(H((\omega_n^k)^2))+\omega_n^kDFT(G((\omega_n^k)^2))  \\ =DFT(H(\omega_n^{2k}))+\omega_n^kDFT(G(\omega_n^{2k}))  \\ =DFT(H(\omega_{n/2}^k))+\omega_n^kDFT(G(\omega_{n/2}^k))

以及:

DFT(F(\omega_n^{k+n/2}))=DFT(H((\omega_n^{k+n/2})^2))+\omega_n^{k+n/2}DFT(G((\omega_n^{k+n/2})^2)) \\ =DFT(H(\omega_n^{2k}))+\omega_n^{k+n/2}DFT(G(\omega_n^{2k})) \\ =DFT(H(\omega_{n/2}^k))-\omega_n^{k}DFT(G(\omega_{n/2}^k))
所以我们可以递归地算出 DFT(H(\omega_{n/2}^k))DFT(\omega_n^{k}(G(\omega_{n/2}^k)) 这两项,同时得到 DFT(F(\omega_n^k))DFT(F(\omega_n^{k+n/2})) .

考虑分治处理的时候必须在两边都取到系数,我们把进行递归DTF的函数补成2的次幂项的多项式.

IDFT

我们用DFT获得了 f(x)g(x) 的点值表示,然后将其按位相乘得到了 F(x) 的点值表示,现在考虑如何得到 F(x) 的系数表示.

y_i=F(\omega_n^i) ,构造一个多项式 G(x)=\sum_{i=0}^{n-1}y_ix^i .

b_i=\omega_n^{-i} 则可以将该多项式表示为:

G(b_k)=\sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{-ik}  \\ =\sum_{i=0}^{n-1}\omega_n^{-ik}\sum_{j=0}^{n-1}a_j(\omega_n^i)^j  \\ =\sum_{i=0}^{n-1}\omega_n^{-ik}\sum_{j=0}^{n-1}a_j\omega_n^{ij} \\ =\sum_{i=0}^{n-1}\sum_{j=0}^{n-1}a_j\omega_n^{i(j-k)} \\ =\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i

考虑后面的 \sum_{i=0}^{n-1}(\omega_n^{j-k})^i 的求和,设 s(\omega_n^k)=\sum_{i=0}^{n-1}(\omega_n^k)^i .

分类讨论,当 k=0 时,显然 s(\omega_n^k)=n .

k \neq 0 时,考虑等比数列求和:

s(\omega_n^k)=\sum_{i=0}^{n-1}(\omega_n^k)^i \\ \omega_n^ks(\omega_n^k)=\sum_{i=1}^n(\omega_n^k)^i \\ (\omega_n^k-1)s(\omega_n^k)=(\omega_n^k)^n-(\omega_n^k)^0 \\ s(\omega_n^k)=\frac{\omega_n^{kn}-1}{\omega_n^k-1} \\ =\frac{(\omega_n^n)^k-1}{\omega_n^k-1} \\ =\frac{1-1}{\omega_n^k-1} \\ =0

所以有:

\sum_{i=0}^{n-1}(\omega_n^k)^i=\left\{\begin{split}n,a=0\\0,a\neq0 \end{split}\right.

带人 G(b_k)=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^{j-k})^i 得:

G(b_k)=\sum_{j=0}^{n-1}a_js(\omega_n^{j-k}) \\ =a_k\cdot n

所以 G(x)x=b_0,b_1,b_2,\cdots,b_n 处取到的点值表示为:

\{(b_0,a_0\cdot n),(b_1,a_1 \cdot n),(b_2,a_2 \cdot n),\cdots,(b_n,a_n \cdot n)\}

把DFT中的 \omega_n^i 换成 \omega_n^{-i} 算出G(x) 的点值表示再将每项除以n就得到 F(x) 的系数表示了.

我们有 \omega_n^{-1}=cos(\frac{2\pi i}{n})+i\cdot sin(-\frac{2\pi i}{n}) ,所以只需稍微修改一下DFT的函数即可实现.

代码实现,计算两个多项式的卷积:

#include "iostream"
#include "cstdio"
#include "cstring"
#include "cstdlib"
#include "cmath"
#include "cctype"
#include "algorithm"
#include "queue"
#include "stack"
#include "set"
#include "map"
#include "vector"
#define R register
#define INF 2147483647
#define debug(x) printf("debug:%lld\n",x)
using namespace std;
const int maxn=1350000;
const double pi=acos(-1);
int n,m;
struct CP
{
	double x,y;
	CP(double x_=0,double y_=0){x=x_,y=y_;}
	CP operator +(CP const &another)const
	{return CP(x+another.x,y+another.y);}
	CP operator -(CP const &another)const
	{return CP(x-another.x,y-another.y);}
	CP operator *(CP const &another)const
	{return CP(x*another.x-y*another.y,y*another.x+another.y*x);}
}f[maxn<<1],g[maxn<<1],tmp[maxn<<1];
void DFT(CP *f,int len,int IDFT)
{
	if(len==1)return;
	for(int i(0);i<len;tmp[i]=f[i],++i);
	for(int i(0);i<len;++i)
	    if(i&1)f[len/2+i/2]=tmp[i];
	    else f[i/2]=tmp[i];
	CP *h=f,*g=f+len/2;
	DFT(h,len/2,IDFT),DFT(g,len/2,IDFT);
	CP cur=(CP){1,0},buf=(CP){cos(2*pi/len),sin(2*pi*IDFT/len)};
	for(int k(0);k<len/2;++k)
	{
		tmp[k]=h[k]+cur*g[k];
		tmp[k+len/2]=h[k]-cur*g[k];
		cur=cur*buf;
	}
	for(int i(0);i<len;f[i]=tmp[i],++i);
}
int main(void)
{
	n=read(),m=read();
	for(R int i=0;i<=n;++i)scanf("%lf",&f[i].x);
	for(R int i=0;i<=m;++i)scanf("%lf",&g[i].x);
	for(m+=n,n=1;n<=m;n<<=1);
	DFT(f,n,1),DFT(g,n,1);
	for(R int i=0;i<n;++i)f[i]=f[i]*g[i];
	DFT(f,n,-1);
	for(R int i=0;i<=m;++i)printf("%lld ",(int)(f[i].x/n+0.49));
	return 0;
}

时间复杂度 O(nlogn).

蝴蝶变换

我们可以仿照递归的方式将这个递归的过程进行优化.

还是以一个8项多项式为例,模拟一下递归的过程:

\{x_0,x_1,x_2,x_3,x_4,x_5,x_6,x_7\} \\  \rightarrow \{x_0,x_2,x_4,x_6\},\{x_1,x_3,x_5,x_7\} \\ \rightarrow \{x_0,x_4\},\{x_2,x_6\},\{x_1,x_5\},\{x_3,x_7\} \\ \rightarrow \{x_0\},\{x_4\},\{x_2\},\{x_6\},\{x_1\},\{x_5\},\{x_3\},\{x_7\}

发现规律,将原序列上每一个位置进行二进制表示,翻转之后得到变换后的位置.

比如 (110)_2 ,翻转后得到 (011)_2 ,将6变换到了3号位.

现在考虑如何递推得到翻转后的值,因为从小到大地来求,所以在计算一个数x的表示时, \lfloor \frac{x}{2} \rfloor 的表示已经确定了,所以我们取到x二进制右移一位的表示,再将其右移一位,就得到了x除开第一位的翻转表示.

考虑第一位在翻转之后到了最高位,如果x为偶数,二进制下第一位为0,不影响翻转,为奇数,第一位为1,把最高位置为1即可.

代码实现:

for(int i=0;i<n;++i)tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
for(int i=0;i<n;++i)if(i<tr[i])swap(f[i],f[tr[i]]);

这样就得到了最终的代码实现,省去了递归过程:

#include "iostream"
#include "cstdio"
#include "cstring"
#include "cstdlib"
#include "cmath"
#include "cctype"
#include "algorithm"
#include "queue"
#include "stack"
#include "set"
#include "map"
#include "vector"
#define R register
#define INF 2147483647
#define debug(x) printf("debug:%lld\n",x)
using namespace std;
const int maxn=1350000;
const double pi=acos(-1);
int n,m;
int tr[maxn<<1];
struct CP
{
	double x,y;
	CP(double x_=0,double y_=0){x=x_,y=y_;}
	CP operator +(CP const &another)const
	{return CP(x+another.x,y+another.y);}
	CP operator -(CP const &another)const
	{return CP(x-another.x,y-another.y);}
	CP operator *(CP const &another)const
	{return CP(x*another.x-y*another.y,y*another.x+another.y*x);}
}f[maxn<<1],g[maxn<<1],tmp[maxn<<1];
inline int read()
{
	char c=getchar();
	int f=1,x=0;
	for(;!isdigit(c);c=getchar())(c=='-')&&(f=-1);
	for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+(c^48);
	return f*x;
}
inline void FFT(CP *f,int IDFT)
{
	for(R int i(0);i<n;++i)if(i<tr[i])swap(f[i],f[tr[i]]);
	for(R int p(2);p<=n;p<<=1)
	{
		CP buf=(CP){cos(2*pi/p),sin(2*pi*IDFT/p)};
		for(R int k(0);k<n;k+=p)
		{
			CP cur=(CP){1,0};
			for(R int l(k);l<k+p/2;++l)
			{
				CP tt=f[l+p/2]*cur;
				f[l+p/2]=f[l]-tt;
				f[l]=f[l]+tt;
				cur=cur*buf;
			}
		}
	} 
}
int main(void)
{
	n=read(),m=read();
	for(R int i=0;i<=n;++i)scanf("%lf",&f[i].x);
	for(R int i=0;i<=m;++i)scanf("%lf",&g[i].x);
	for(m+=n,n=1;n<=m;n<<=1);
	for(R int i(0);i<=n;tr[i]=(tr[i>>1]>>1)|((i&1)?(n>>1):0),++i);
	FFT(f,1),FFT(g,1);
	for(R int i=0;i<n;++i)f[i]=f[i]*g[i];
	FFT(f,-1);
	for(R int i=0;i<=m;++i)printf("%lld ",(int)(f[i].x/n+0.49));
	return 0;
}

快速数论变换(NTT)

于是可以引入基于原根的快速数论变换,可以在模意义下计算多项式的卷积.

(a,m)=1 ,使得 a^l \equiv 1 ~(mod ~ m) 的最小的 l 称为 a 关于模 m 的阶,记作 ord_ma .

(g,m)=1 ,若ord_mg=\varphi(m) ,则称 gm 的原根.

gm 的原根当且仅当 \{g,g^2,g^3,\cdots,g^{\varphi(m)}\}m 的简化剩余系.

m有 原根则一定满足下列形式: 2,4,p^a,2p^a , p 为奇素数, a为正整数.

m 为素数则这个简化剩余系取遍 \{1,2,3,\cdots,m-1\}g,g^2,g^3,\cdots,g^{m-1} 两两不同,这一点在快速数论变换中非常重要.

对于素数 m=t2^k+1 ,我们定义 g^{\frac{m-1}{n}}g_n ,因为 m=t2^k+1 所以一定能取整,为了把它带入DFT,现在考虑证明 g_n 是否满足 \omega_n 的所有性质.

首先 g_n^k=g_{2n}^{2k} 是显然的.

然后 g_n^n=1 ~(mod ~m) 可以写成 g^{m-1} \equiv 1~(mod~m) ,费马小定理,得证.

还有 g_n^{k+n/2}=-g_n^k :

g_n^{k+n/2}=g^{(\frac{(m-1)}{n})^k}*g^{\frac{m-1}{2}}=g_n^k*g_n^{2/n}

又有(g_n^{n/2})^2=g_n^n=1

所以 g_n^{n/2} 等于 1-1 ,又因为 g,g^2,g^3,\cdots,g^{m-1} 两两不同且 g_n^n=1 ,所以 g_n^{n/2}=-1 ,因为是模意义下所以也可以说等于 m-1 .

IDFT中的 g_n^{-1} 取一个逆元即可.

所以只需对单位复根做一个简单的替换就能得到快速数论变换了.

素数和原根一般取 m=1004535809=479\times 2^{21}+1,g=3m=998244353=119*2^{23}+1,g=3 .

代码实现:

#include "iostream"
#include "cstdio"
#include "cstring"
#include "cstdlib"
#include "cmath"
#include "cctype"
#include "algorithm"
#include "set"
#include "queue"
#include "map"
#include "vector"
#include "stack"
#define lxl long long
#define R register
#define INF 2147483647
#define debug(x) printf("debug:%lld\n",x)
using namespace std;
const lxl G=3,mod=998244353,maxn=1350000;
lxl n,m,InvN,InvG;
lxl f[maxn<<1],g[maxn<<1],tr[maxn<<1];
inline lxl read()
{
    char c;
    lxl f=1,x=0;
    for(;!isdigit(c);c=getchar())(c=='-')&&(f=-1);
    for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+(c^48);
    return f*x;
}
inline lxl QPow(lxl x,lxl y,lxl t)
{
    lxl base=1;
    for(;y;x=x*x%t,y>>=1)(y&1)&&(base=base*x%t);
    return base;
}
inline void NTT(lxl *f,bool falg)
{
    for(R lxl i=0;i<n;++i)if(i<tr[i])swap(f[i],f[tr[i]]);
    for(R lxl p=2;p<=n;p<<=1)
    {
        lxl len=p>>1,tg=QPow(falg==true?G:InvG,(mod-1)/p,mod);
        for(R lxl k=0;k<n;k+=p)
        {
            lxl buf=1;
            for(R lxl l=k;l<k+len;++l)
            {
                lxl tt=(buf*f[l+len])%mod;
                f[l+len]=f[l]-tt;
                if(f[l+len]<0)f[l+len]+=mod;
                f[l]=f[l]+tt;
                if(f[l]>mod)f[l]-=mod;
                buf=buf*tg%mod;
            }
        }
    }
}
int main(void)
{
    n=read()+1,m=read()+1;
    for(R lxl i=0;i<n;++i)f[i]=read();
    for(R lxl i=0;i<m;++i)g[i]=read();
    for(m=n+m,n=1;n<m;n<<=1);
    InvN=QPow(n,mod-2,mod),InvG=QPow(G,mod-2,mod);
debug(InvN),debug(n);
    for(R lxl i=0;i<n;++i)tr[i]=(tr[i>>1]>>1)|((i&1)?n>>1:0);
    NTT(f,true),NTT(g,true);
    for(R lxl i=0;i<n;++i)f[i]=f[i]*g[i]%mod;
    NTT(f,false);
    for(R lxl i=0;i<m-1;++i)printf("%lld ",f[i]*InvN%mod);
    printf("\n");
    return 0;
}

编辑于 08-07