FFT NTT小结

一个问题

给定\(a_i,b_i\),定义\(c_n=\sum_{i=0}^n a_ib_{n-i}\),求\(c\)序列。

想法

暴力乘需要\(O(nm)\),有没有办法优化呢?

能不能把多项式转化一下,然后快速的计算再转回来呢?

答案是能的。

多项式有两种表示法:点值表示法和系数表示法。

系数表示法是用n+1个数\(a_i\)来表示,也就是\(F(x)=\sum_{i=0}^na_ix^i\)

点值表示法是用n+1个点\((x_i,y_i)\)来表示,也就是\(x=x_i\)\(F(x)=y_i\)

表示出来以后,可以直接\(O(n)\)计算:如果x相同的两个点分别是\((x,f(x)),(x,g(x))\),那么乘积显然是

\((x,f(x)g(x))\)

当然由于乘积的次数是\(n+m\)次的,至少要用\(n+m+1\)个点。

问题就在于如何在系数表示法和点值表示法之间转换。

暴力用\(x=0..n+m\)带入计算然后之后插值计算原函数吗?先不说插值的复杂度(我也不知道没写过),点值计算的复杂度就已经有\(O((n+m)^2)\) ,好垃圾啊

单位复数根

\(\omega^n=1\),则\(\omega\)有n个解,在复平面单位圆上,排成了一个正n边形,其中一个顶点是1。

定义\(\omega_n\)为最小的那个解,也就是从1逆时针方向第一个解,比如\(\omega_4=i\)

那么\(\omega_n^i\)其中\(i\)从0到n-1刚好构成了全部的解。(\(\omega_n^0=1\))

这个东西满足一些性质:

  1. \(\omega_n^i=\omega_n^{i+n}\)
  2. 如果\(n\)为偶数,(n=2k),则\(\omega_n^i+\omega_n^{k+i}=0\)
  3. 如果\(n\)为偶数,(n=2k),则\(\omega_n^{2i}=\omega_k^i\)
  4. \(\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{vi}=[n|v]\)

这些性质很重要

则对于任意\(a,b,c\),显然

\[ \begin{align} [n|a+b-c]&=\sum_{i=0}^{n-1}\omega_n^{(a+b-c)i}\\ &=\frac{1}{n}\sum_{i=0}^{n-1}\omega_n^{ai}\omega^{bi}\omega^{-ci} \end{align} \]

成立

DFT

\(n\)是大于c长度的最小的形如\(2^k\)的数。

\[ \begin{align} c_m&=\sum_{i}\sum_{j}[i+j=m]a_ib_j\\ &=\sum_{i}\sum_{j}[n|i+j-m]a_ib_j\\ &=\frac{1}{n}\sum_{i}\sum_{j}\sum_{v=0}^{n-1}a_i\omega_n^{iv}b_j\omega_n^{jv}\omega^{-mv}\\ &=\frac{1}{n}\sum_{v=0}^{n-1}\omega^{-mv}\sum_{i}a_i\omega_n^{iv}\sum_{j}b_j\omega_n^{jv}\\ \end{align} \]\[ A(x)=\sum_{i=0}^{n-1}a_ix^i\\ B(x)=\sum_{i=0}^{n-1}b_ix^i\\ c_m=\frac{1}{n}\sum_{i=0}^{n-1}C_i\omega_n^{-im}\\ C(x)=\sum_{i=0}^{n-1}C_ix^i \]

\[ C_v=A(\omega_n^v)B(\omega_n^v)\\ c_m=\frac{1}{n}\sum_{v=0}^{n-1}\omega^{-mv}A(\omega_{n}^v)B(\omega_{n}^v)\\ c_m=\frac{1}{n}C(\omega^{-m}) \] 所以我们就只需要从\(a_{0}\dots a_{n-1}\)计算\(A(\omega_n^{0})\dots A(\omega_n^{n-1})\),以及从\(C_i\)计算\(c_{i}\),前者是从系数表示法转换为点值表示(DFT),后者点值表示转换为系数表示(IDFT)。

FFT

一个引理:设\(F(x)=a_0+a_1x+a_2x^2+...\)\(F_0(x)=a_0+a_2x+a_4x^2+a_6x^3+...\),\(F_1(x)=a_1+a_3x+a_5x^2+...\),那么\(F(x)=F_0(x^2)+xF_1(x^2)\)

所以

\[ F(\omega_n^m)=F_0(\omega_{\frac{n}{2}}^m)+\omega_n^mF_1(\omega_{\frac{n}{2}}^m)\\ F(\omega_n^{m+\frac{n}{2}})=F_0(\omega_{\frac{n}{2}}^m)+\omega_n^{m+\frac{n}{2}}F_1(\omega_{\frac{n}{2}}^m)=F_0(\omega_{\frac{n}{2}}^m)-\omega_n^{m}F_1(\omega_{\frac{n}{2}}^m) \]

所以我们只需要先给\(F_0\)\(F_1\)DFT,然后\(O(n)\)合并。

显然\(F_0\)\(F_1\)都是\(\frac{n}{2}\)次的,那么复杂度为\(T(n)=2T(\frac{n}{2})+O(n)\)根据主定理复杂度为\(T(n\log n)\)

那IDFT怎么计算呢?

根据前面的式子,IDFT只需要把\(\omega_n^m\)改成\(\omega_n^{-m}\)再DFT一遍最后每一项\(/n\)就好了。

或者还有一种写法:直接再DFT一遍,然后把\(c_1\dots c_{n-1}\)反转,然后把每一项\(/n\)

至于单位复数根怎么计算呢?

根据欧拉定理,\(\omega_n^m=\cos(\frac{2\pi m}{n})+i\sin(\frac{2\pi m}{n})\)

当然也可以得到了\(\omega_n^1\)\(\omega_n^m=\omega_n^{m-1}\omega_n^1\)递推。

但是直接递归有点慢....因为要反复的把元素放进两个数组再合并回来,有没有办法优化呢?

发现递归的过程形如这样(系数表示法)设n=8: \[ (a_0,a_1,a_2,a_3,a_4,a_5,a_6,a_7)\\ (a_0,a_2,a_4,a_6)(a_1,a_3,a_5,a_7)\\ (a_0,a_4)(a_2,a_6)(a_1,a_5)(a_3,a_7)\\ (a_0)(a_4)(a_2)(a_6)(a_1)(a_5)(a_3)(a_7) \] 如果我们先按最下面的顺序排好,然后向上一层一层的合并,是不是就节省了很多时间呢?

发现(0,4,2,6,1,5,3,7)和(0,1,2,3,4,5,6,7)二进制恰好相反,这是不是巧合?

仔细思考:

把这n个元素分成两组的时候,是按照最低位为0还是1分的,为0则会分到前面一半(最高位为0),为1则会分到后面一半。

也就是最低位会对应到换之后的最高位,然后次低位也是一样会对应到次高位....

所以就是每个元素编号的二进制位翻转。

这个称作位逆序置换。

如何计算位逆序置换呢?

\(rev(x)\)表示x的位逆序置换,那么\(rev(x)=rev(\lfloor\frac{x}{2}\rfloor)+(x\bmod 2)\times\frac{n}{2}\)

这个不难证明 。

代码:

typedef complex<double> fs;
void FFT(fs *a,int flag){
    for(int i=0;i<s;++i)if(i<rev[i])swap(a[i],a[rev[i]]);   
    for(int i=1;i<s;i<<=1){
        fs wa(cos(pi/i),flag*sin(pi/i));
        for(int j=0;j<s;j+=i+i){
            fs w(1,0),u,v;
            for(int k=0;k<i;++k){
                u=a[j+k],v=w*a[j+k+i];
                a[j+k]=u+v;
                a[j+k+i]=u-v;
                w=w*wa;
            }
        }
    }
    if(flag<0){
        for(int i=0;i<s;++i)a[i]/=s;
    }
}

FFT(a,1)代表对a进行DFT,FFT(a,-1)代表IDFT

当然还有一种常数优化:手写复数类型而不是用complex。

NTT

还有一些题目,要求膜意义下卷积,也就是卷出来的系数要膜dkw一个数。

一般膜的数可以表示为\(2^ab+1\)的形式,其中\(2^a\)要不小于n。(否则就需要一些奇怪的操作)

思考\(\omega\)用在FFT中的性质 这些在膜意义下有什么东西替代吗?

答案是原根。

由于费马小定理\(g^{p-1}=1\),我们可以用\(g^\frac{p-1}{n}\)代替FFT的\(\omega_n^1\),发现也同样满足这些性质。

然后对应的IDFT 用的就是\(g^{-\frac{p-1}{n}}\)

代码:

const ll p=998244353;
const ll g=3;
ll fpm(ll a,ll b){
    ll c=1;
    while(b){
        if(b&1)c=c*a%p;
        a=a*a%p;
        b>>=1;
    }
    return c;
}

ll invs;
void NTT(ll *a,ll flag){
    for(ll i=0;i<s;++i)if(i<rev[i])swap(a[i],a[rev[i]]);
    for(ll i=1;i<s;i<<=1){
        ll w=fpm(g,((flag+(p-1))*(p-1)/i/2));
        for(ll j=0;j<s;j+=i+i){
            ll aw=1,u,v;
            for(ll k=0;k<i;++k){
                u=a[j+k],v=a[j+k+i]*aw%p;
                a[j+k]=(u+v)%p;
                a[j+k+i]=(u-v+p)%p;
                aw=aw*w%p;
            }
        }
    }
    if(flag<0)for(ll i=0;i<s;++i)a[i]=a[i]*invs%p;
}

例题

UOJ34 多项式乘法

一个模板题。

FFT和NTT都可以

代码:

FFT NTT

ZJOI 2014 力

\[ E_i=\sum_{j<i}q_j\frac{1}{(i-j)^2}-\sum_{i<j}q_j\frac{1}{(j-i)^2} \]

\[ F_i=\sum_{j<i}q_j\frac{1}{(i-j)^2}\\ G_i=\sum_{i<j}q_j\frac{1}{(j-i)^2}\\ E_i=F_i-G_i \]

\(F_i\)一看就能直接卷积,后面一半感觉也像卷积形式,可是\(j>i\)怎么解决呢?

\(q_j,G_i\)都翻转就能卷积了。

fs q[1<<19],qr[1<<19],k[1<<19],f[1<<19],F[1<<19];
int rev[1<<19];
 
void FFT(fs *a,int base,int flag){
    for(int i=0;i<base;++i)if(i<rev[i]){
        swap(a[i],a[rev[i]]);   
//      printf("Swap%d,%d\n",i,rev[i]);
    }
    fs ww,w,u,v;
    for(int S=2,s=1;s<base;S<<=1,s<<=1){
        ww=fs(cos(pi/s),flag*sin(pi/s));
        for(int l=0;l<base;l+=S){
            w=fs(1,0);
            for(int j=0;j<s;++j){
                u=a[l+j];
                v=w*a[l+j+s];
                a[l+j]=u+v;
                a[l+j+s]=u-v;
                w=w*ww;
            }
        }
    }
    if(flag<0)for(int i=0;i<base;++i)a[i].x/=base,a[i].y/=base;
}
int main(){
#ifdef cnyali_lk
    freopen("3527.in","r",stdin);
    freopen("3527.out","w",stdout);
#endif  
    int n;
    scanf("%d",&n);
    for(int i=1;i<=n;++i){
        scanf("%lf",&q[i].x);
        qr[n+1-i]=q[i];//翻转
        k[i].x=1./i/i;
    }
    int m=n<<1,base=1,l=-1;
    while(base<=m){base<<=1;++l;}
 
    for(int i=0;i<base;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<l);
    FFT(q,base,1);
    FFT(qr,base,1);
    FFT(k,base,1);
 
    for(int i=0;i<base;++i)f[i]=q[i]*k[i],F[i]=qr[i]*k[i];
    FFT(f,base,-1);
    FFT(F,base,-1);
    for(int i=1;i<=n;++i)
        printf("%lf\n",f[i].x-F[n+1-i].x);//翻转
    return 0;
}