快速计算数列a的多点等幂和

快速计算数列a的多点等幂和

问题概述:

给定数列 \{a_n\} ,上界 m ,对于所有的 k \leq m , 求\sum_{i=1}^{n}{a_i^k} 的值。

更新一个新的方法:

直接写出答案的生成函数: H(x)=\sum_{i\ge0}\sum_{j=1}^n {a_j}^ix^i=\sum_{j=1}^n\sum_{i\ge0} {(a_jx)}^i=\sum_{i=1}^n\frac1{1-a_ix}=\frac{\sum_{i=1}^n{\prod_{j\not =i}(1-a_jx)}}{\prod_{i=1}^n(1-a_ix)}

F(x)=\prod_{i=1}^n(1-a_ix),G(x)=\sum_{i=1}^n{\prod_{j\not =i}(1-a_jx)}F(x) 可以利用分治FFT在 O(n\log^2 n) 算出来。而 G(x)x^i 的系数恰好就是 F(x) 对应项系数乘上 n-i ,即 [x^i]G(x)=(n-i)[x^i]F(x)

考虑式子的组合意义, F(x)x^i 的系数表示从这 n 个数中选出 i 个数的乘积之和再乘上 (-1)^i ,稍微观察一下 G(x) 的形式,就会发现对于一个选取了 j 个数的方案,它会在 G(x) 中重复出现 n-j 次,所有选取数量相同的方案被算重的方案数是相同的,那么直接对 F(x) 的系数乘上被算重次数 n-i ,就是 G(x) 的系数了。

设数列长度为 n ,需要计算的幂和的上限为 m ,那么分治FFT时间复杂度 O(n\log^2n) ,求逆+乘法时间复杂度 O(m\log m) ,所以时间复杂度 O(n\log^2n+m\log m) 。与方法二相同但是不需要多项式取对。


原文:

注意到直接计算该和式需要 O(n\log m)m 次计算则需要花费 O(nm\log m) 的时间。

如果我们改变枚举顺序,时间复杂度可降至 O(nm)

对于数列 \{a_n\} ,我们可以设出一个生成函数 E(x)=\prod{(1+a_ix)}。似乎这个生成函数也没有什么特殊的性质,尝试对其求对数:

G(x)=\ln E(x) ,那么有:

G(x)=\ln \prod_{i=1}^{n}(1+a_ix)=\sum_{i=0}^{n} \ln(1+a_ix)\\

利用 \ln (x+1)x=0 处展开的麦克劳林展开式 \ln (1+ax)=\sum_{n \geq 1}{\frac{(-1)^{n+1}}{n} x^n} ,可得: \begin{align} G(x)&=\sum_{i=1}^{n}{\ln(1+a_ix)}\\ &=\sum_{i=1}^{n}{\sum_{k\geq1}{\frac{(-1)^{k+1}}{k}}a_i^kx^k}\\ \end{align}\\

发现生成函数中出现了 a_i^k 的结构。接下来尝试变换枚举顺序:

\begin{align} G(x)&=\sum_{k \geq 1}\sum_{i=1}^{n}{\frac{(-1)^{k+1}}{k}a_i^kx^k}\\ &=\sum_{k \geq 1}\frac{(-1)^{k+1}}{k}\sum_{i=1}^{n}{a_i^kx^k}\\ \end{align}\\

式子中出现了我们求和的东西: \sum_{i=1}^{n}{a_i^k} 。接下来的难点便在于如何快速求出 G(x)

首先 E(x)=\prod_{i=1}^{n}(1+a_ix)

P(x)=\prod_{i=1}^{\lceil \frac{n}{2} \rceil}(1+a_ix), Q(x)=\prod_{i=\lceil \frac{n}{2} \rceil+1}^{n}(1+a_ix),则 E(x)=P(x)Q(x)

结合FFT以及分治算法,E(x) 可以在 O(n\log^2n) 时间内被算出。

接下来计算 G(x)=\ln E(x) 。由于题目给定了K,只需要知道 G(x) 的前K+1项,那么我们把式子放在取模的意义下:

G(x) \equiv \ln E(x)\mod(x^{m+1})\\

两边同时对 x 求导:

G'(x)\equiv \frac{E'(x)}{E(x)}\mod(x^{m+1})\\

同时积分:

G(x)\equiv \int\frac{E'(x)}{E(x)}dx\mod(x^{m+1})\\

我们得到了 G(x) 的表达式,其中多项式求逆 O(m\log m) ,求导以及积分 O(m)

至此问题得以全部解决,总时间复杂度 O(n\log^2n+m\log m)


例题:玩游戏(【P4705】玩游戏 - 洛谷

Alice 和 Bob 又在玩游戏。
对于一次游戏,首先 Alice 获得一个长度为 n 的序列 \{a_n\} ,Bob 获得一个长度为 m 的序列 \{b_n\} 。之后他们各从自己的序列里随机取出一个数,分别设为 a_x,b_y ,定义这次游戏的 k 次价值为 (a_x+b_y)^k
由于他们发现这个游戏实在是太无聊了,所以想让你帮忙计算对于 i=1,2,3,...,t ,一次游戏 i 次价值的期望是多少。
由于答案可能很大,只需要求出模  998244353 下的结果即可。

输入格式:

第一行两个整数 n,m(1≤n,m≤10^5) ,分别表示 Alice 和 Bob 序列的长度。
接下来一行 n 个数,第 ii 个数为 a_i​(0≤a_i​<998244353) ,表示 Alice 的序列。
接下来一行 m 个数,第 j 个数为b_j​(0≤b_j​<998244353) ,表示 Bob 的序列。
接下来一行一个整数 t (1≤t≤10^5) ,意义如上所述。

输出格式:

t 行,第 i 行表示一次游戏 i 次价值的期望。

样例输入:

2 8
764074134 743107904
663532060 183287581 749169979 7678045 393887277 27071620 13482818 125504606
6

样例输出:

774481679
588343913
758339354
233707576
36464684
461784746


对于某一个 kk 次价值期望值为 E_k=\frac{\sum_{i=1}^{n}\sum_{i=1}^{m}{(a_i+b_y)^k}}{nm} 。那么我们只需要计算分子和式,最后除以定值 nm 就行了。

\begin{align} nmE_k&=\sum_{i=1}^{n}\sum_{j=1}^{m}(a_i+b_j)^k\\ &=\sum_{i=1}^{n}\sum_{j=1}^{m}\sum_{p=1}^{p}\binom{k}{p}a_i^pb_j^{k-p}\\ &=\sum_{p=1}^{k}\binom{k}{p}\sum_{i=1}^{n}a_i^p\sum_{j=1}^{m}b_j^{k-p}\\ &=\sum_{p=1}^{k}\frac{k!}{p!(k-p)!}\sum_{i=1}^{n}a_i^p\sum_{j=1}^{m}b_j^{k-p}\\ &=k!\sum_{p=1}^{k}\frac{\sum_{i=1}^{n}a_i^p}{p!}\frac{\sum_{j=1}^{m}b_j^{k-p}}{(k-p)!}\\ \end{align}\\

发现这是一个卷积的形式,于是令 A(x)=\sum_{p=1}^{k}{\frac{\sum_{i=1}^{n}a_i^p}{p!}}x^p, B(x)=\sum_{p=1}^{k}{ \frac{\sum_{i=1}^{m}b_i^{p}}{p!}x^p }\\ \\

可以利用上文快速求和的方法,计算 A(x),B(x) 的系数。

答案 nmE_k=k![x^k]A(x)B(x) ,可以用FFT快速算出。整体时间复杂度 O(n\log^2n+k\log k)

#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define loop(n,i) for(register int i=1;i<=(n);++i)
#define MAX 6662333
using namespace std;
/*玩游戏*/
inline int getint(){
    char c=getchar();int s=0;bool sign=0;
    while(!isdigit(c)&&c^'-')c=getchar();
    if(c=='-')c=getchar(),sign=1;
    while(isdigit(c))s=(s<<1)+(s<<3)+c-'0',c=getchar();
    return sign?-s:s;
}
#define P 998244353
#define int long long
struct Poly{
    int a[MAX];
    int length;
    int& operator [](const int x){return a[x];}
    const int operator [](const int x)const{return a[x];}
    inline void operator *=(const Poly& X){
        for(register int i=0,j=1<<max(length,X.length);i<j;++i) a[i]=a[i]*X[i]%P;
    }
    Poly(int x=0):length(x){}
};
ostream& operator <<(ostream& out,Poly& X){
    for(register int i=0,j=1<<X.length;i<j;++i) out<<X[i]<<' ';return out;
}
int R[MAX];
inline int Quick(int a,int m){int ans=1;for(;m;m>>=1,a=a*a%P) if(m&1) ans=ans*a%P;return ans;}
inline void NTT(Poly& a,int type){
    int x=a.length,n=1<<x;
    for(register int i=0;i<n;++i) if((R[i]=R[i>>1]>>1|(i&1)<<x-1)>i) swap(a[i],a[R[i]]);
    for(register int s=1;s<=x;++s){
        int len=1<<s,mid=len>>1;
        int w=Quick(3,P-1+type*(P-1)/len);
        for(register int k=0;k<n;k+=len){
            int d=1;
            for(register int j=0;j<mid;++j){
                int u=a[j+k],v=a[j+k+mid]*d%P;
                a[j+k]=(u+v)%P;
                a[j+k+mid]=(u-v+P)%P;
                d=d*w%P;
            }
        }
    }
    if(type==-1) for(register int i=0,inv=Quick(n,P-2);i<n;++i) a[i]=a[i]*inv%P;
}

Poly A;
inline void Inverse(Poly& a,Poly& b){
    for(register int i=0,j=1<<a.length+1;i<j;++i) b[i]=A[i]=0;
    b[0]=Quick(a[0],P-2);
    for(register int s=1;s<=a.length;++s){
        int len=1<<s,Len=len<<1;
        for (register int i=0;i<len;++i) A[i]=a[i];
        b.length=A.length=s+1;
        NTT(b,1),NTT(A,1);
        for(register int i=0;i<Len;++i) b[i]=(2*b[i]%P-b[i]*b[i]%P*A[i]%P+P)%P;
        NTT(b,-1);
        for(register int i=len;i<Len;++i) b[i]=0;
    }
    b.length=a.length;
}

inline void Derivation(Poly& a,Poly& b){
    for(register int i=0,j=1<<a.length+1;i<j;++i) b[i]=0;
    for(register int i=0,j=1<<a.length;i<j;++i) b[i]=a[i+1]*(i+1)%P;
    b[(1<<(b.length=a.length))-1]=0;
}
inline void Integrate(Poly& a,Poly& b){
    for(register int i=0,j=1<<a.length+1;i<j;++i) b[i]=0;
    R[1]=1;for(register int i=2,j=1<<a.length;i<j;++i) R[i]=(P-P/i)*R[P%i]%P;
    for(register int i=0,j=1<<a.length;i<j;++i) b[i+1]=a[i]*R[i+1]%P;
    b[0]=0;b.length=a.length+1;
}

Poly C;
inline void Ln(Poly& a,Poly& b){
    for(register int i=0,j=1<<a.length+1;i<j;++i) b[i]=C[i]=0;
    Derivation(a,b);
    Inverse(a,C);
    C.length++,b.length++;
    NTT(C,1),NTT(b,1);C*=b;NTT(C,-1);
    C.length--,b.length--;
    for(register int i=1<<C.length,j=1<<C.length+1;i<j;++i) b[i]=0;
    Integrate(C,b);
    b.length=a.length;
}

int memo[MAX];
Poly AA,BB;
int Calculate(int l,int r,const Poly& a){
    if(l==r) return memo[l]=a[l],1;
    int mid=l+r>>1;
    AA[0]=Calculate(l,mid,a);
    BB[0]=Calculate(mid+1,r,a);
    AA.length=0;
    while((1<<AA.length)<=(r-l+1)) AA.length++;BB.length=AA.length;
    for(register int i=1,j=1<<AA.length;i<j;++i) AA[i]=BB[i]=0;
    for(register int i=l;i<=mid;++i) AA[i-l+1]=memo[i];
    for(register int i=mid+1;i<=r;++i) BB[i-mid]=memo[i];
    NTT(AA,1),NTT(BB,1);AA*=BB;NTT(AA,-1);
    for(register int i=l;i<=r;++i) memo[i]=AA[i-l+1];
    return AA[0];
}

int n,m,t;
Poly a,b,temp;
main(){
    n=getint(),m=getint();
    loop(n,i) a[i]=getint();
    loop(m,i) b[i]=getint();
    t=getint(); 
    while((1<<a.length)<=t) a.length++;b.length=a.length;
    
    memo[0]=Calculate(1,n,a);
    for(register int i=0;i<=n;++i) a[i]=memo[i];
    Ln(a,temp);a=temp;
    memo[0]=Calculate(1,m,b);
    for(register int i=0;i<=m;++i) b[i]=memo[i];
    Ln(b,temp);b=temp;
    
    R[1]=1;for(register int i=2;i<=t;++i) R[i]=(P-P/i)*R[P%i]%P;
    for(register int k=1,sign=1,fact=1;k<=t;++k,sign=P-sign,fact=fact*R[k]%P){
        a[k]=a[k]*k%P*sign%P*fact%P;
        b[k]=b[k]*k%P*sign%P*fact%P;
    }
    a[0]=n,b[0]=m;
    a.length++,b.length++;
    for(register int i=t+1,j=1<<a.length;i<j;++i) a[i]=b[i]=0;
    NTT(a,1),NTT(b,1);a*=b;NTT(a,-1);
    int inv=Quick(n*m%P,P-2);
    for(register int i=1,fact=1;i<=t;++i,fact=fact*i%P){
        cout<<fact*a[i]%P*inv%P<<'\n';
    }
}

编辑于 2019-04-17