给定一个n-1次多项式$f(x)$,保证$a_0=1$。求$\ln(f(x))$对$x^n$取模的结果。系数模998244353
$\ln(f(x))$定义为其幂级数展开,对$x^n$取模为其幂级数的前n项和。
多项式乘法(NTT),多项式求逆,多项式求导、积分(这个所有人都会)
设$g(x)=\ln(f(x))$,则$g'(x)\equiv\frac{f'(x)}{f(x)}\equiv f'(x)f^{-1}(x)\pmod{x^n}$。
由多项式求逆算得$f^{-1}(x)$(对$x^n$取模),求导得到$f'(x)$,然后NTT算得$f'(x)f^{-1}(x)$,这就是$g'(x)$对$x^n$取模的结果。
最后再积分得到$g(x)$即可。显然,g(x)的常数项应该是0
时间复杂度O($n\log n$),空间复杂度O(n)
模板:洛谷4725
#include<iostream> #include<cstdio> #include<cstring> #define N 200005 #define LL long long using namespace std; const LL mod=998244353; int n,m; int rank[N<<1]; LL f1[N<<1],f2[N<<1]; LL ksm(LL x,LL y) { LL res=1; while(y) { if(y&1) res=res*x%mod; y>>=1; x=x*x%mod; } return res; } void ntt(LL *a,int type) { int i,j,k; for(i=0;i<n;i++) if(i<rank[i]) swap(a[i],a[rank[i]]); for(i=1;i<n;i<<=1) { LL t1=(type>0)?ksm(3,(mod-1)/(i<<1)):ksm((mod+1)/3,(mod-1)/(i<<1)); for(j=0;j<n;j+=(i<<1)) { LL t2=1; for(k=0;k<i;k++) { LL t3=a[j+k],t4=t2*a[j+k+i]%mod; a[j+k]=(t3+t4)%mod; a[j+k+i]=(t3-t4+mod)%mod; t2=t2*t1%mod; } } } } void mul(LL *a1,LL *a2,int l) { int i; if(n<l||(n>>1)>=l) { n=1; while((1<<n)<l) n++; for(i=0;i<(1<<n);i++) rank[i]=(rank[i>>1]>>1)|((i&1)<<n-1); n=1<<n; } ntt(a1,1); ntt(a2,1); for(i=0;i<n;i++) a1[i]=a1[i]*a2[i]%mod; ntt(a1,-1); LL inv=ksm(n,mod-2); for(i=0;i<n;i++) a1[i]=a1[i]*inv%mod; } LL tf0[N<<1],tf1[N<<1],tf2[N<<1],tf3[N<<1]; void invf(LL *a,int l) { int i,j; tf0[0]=ksm(a[0],mod-2); for(i=2;i<(l<<1);i<<=1) { for(j=0;j<i;j++) tf1[j]=tf2[j]=tf0[j]; mul(tf1,tf2,i<<1); memset(tf2+i,0,sizeof(LL)*i); for(j=0;j<i;j++) tf2[j]=a[j]; mul(tf1,tf2,i<<1); for(j=0;j<i;j++) tf0[j]=(tf0[j]*2-tf1[j]+mod)%mod; } for(i=0;i<l;i++) a[i]=tf0[i]; } void diff(LL *a,int l) { int i; for(i=0;i<l;i++) a[i]=a[i+1]*(i+1)%mod; } void calc(LL *a,int l) { int i; for(i=l;i;i--) a[i]=a[i-1]*ksm(i,mod-2)%mod; a[0]=0; } void log(LL *a,int l) { int i; for(i=0;i<l;i++) tf3[i]=a[i]; invf(a,l); diff(tf3,l); mul(a,tf3,l<<1); calc(a,l-1); memset(a+l,0,sizeof(LL)*((N<<1)-l-1)); } int main() { int i,j; scanf("%d",&m); for(i=0;i<m;i++) scanf("%lld",&f1[i]); log(f1,m); for(i=0;i<m;i++) printf("%lld ",f1[i]); return 0; }