目录

问题描述

给定一个n-1次多项式$f(x)$,保证$a_0=1$。求$\ln(f(x))$对$x^n$取模的结果。系数模998244353

$\ln(f(x))$定义为其幂级数展开,对$x^n$取模为其幂级数的前n项和。

解决方法

前置知识

多项式乘法(NTT),多项式求逆,多项式求导、积分(这个所有人都会)

正文

设$g(x)=\ln(f(x))$,则$g'(x)\equiv\frac{f'(x)}{f(x)}\equiv f'(x)f^{-1}(x)\pmod{x^n}$。

由多项式求逆算得$f^{-1}(x)$(对$x^n$取模),求导得到$f'(x)$,然后NTT算得$f'(x)f^{-1}(x)$,这就是$g'(x)$对$x^n$取模的结果。

最后再积分得到$g(x)$即可。显然,g(x)的常数项应该是0

问题分析

时间复杂度O($n\log n$),空间复杂度O(n)

模板:洛谷4725

#include<iostream>
#include<cstdio>
#include<cstring>
#define N 200005
#define LL long long
using namespace std;
 
const LL mod=998244353;
int n,m;
int rank[N<<1];
LL f1[N<<1],f2[N<<1];
LL ksm(LL x,LL y)
{
	LL res=1;
	while(y)
	{
		if(y&1) res=res*x%mod;
		y>>=1;
		x=x*x%mod;
	}
	return res;
}
void ntt(LL *a,int type)
{
	int i,j,k;
	for(i=0;i<n;i++)
		if(i<rank[i]) swap(a[i],a[rank[i]]);
	for(i=1;i<n;i<<=1)
	{
		LL t1=(type>0)?ksm(3,(mod-1)/(i<<1)):ksm((mod+1)/3,(mod-1)/(i<<1));
		for(j=0;j<n;j+=(i<<1))
		{
			LL t2=1;
			for(k=0;k<i;k++)
			{
				LL t3=a[j+k],t4=t2*a[j+k+i]%mod;
				a[j+k]=(t3+t4)%mod;
				a[j+k+i]=(t3-t4+mod)%mod;
				t2=t2*t1%mod;
			}
		} 
	}
}
void mul(LL *a1,LL *a2,int l)
{
	int i;
	if(n<l||(n>>1)>=l)
	{
		n=1;
		while((1<<n)<l) n++;
		for(i=0;i<(1<<n);i++)
			rank[i]=(rank[i>>1]>>1)|((i&1)<<n-1);
		n=1<<n;
	}
	ntt(a1,1); ntt(a2,1);
	for(i=0;i<n;i++)
		a1[i]=a1[i]*a2[i]%mod;
	ntt(a1,-1);
	LL inv=ksm(n,mod-2);
	for(i=0;i<n;i++)
		a1[i]=a1[i]*inv%mod;
}
LL tf0[N<<1],tf1[N<<1],tf2[N<<1],tf3[N<<1];
void invf(LL *a,int l)
{
	int i,j;
	tf0[0]=ksm(a[0],mod-2);
	for(i=2;i<(l<<1);i<<=1)
	{
		for(j=0;j<i;j++)
			tf1[j]=tf2[j]=tf0[j];
		mul(tf1,tf2,i<<1);
		memset(tf2+i,0,sizeof(LL)*i);
		for(j=0;j<i;j++)
			tf2[j]=a[j];
		mul(tf1,tf2,i<<1);
		for(j=0;j<i;j++)
			tf0[j]=(tf0[j]*2-tf1[j]+mod)%mod;
	}
	for(i=0;i<l;i++)
		a[i]=tf0[i];
}
void diff(LL *a,int l)
{
	int i;
	for(i=0;i<l;i++)
		a[i]=a[i+1]*(i+1)%mod;
}
void calc(LL *a,int l)
{
	int i;
	for(i=l;i;i--)
		a[i]=a[i-1]*ksm(i,mod-2)%mod;
	a[0]=0;
}
void log(LL *a,int l)
{
	int i;
	for(i=0;i<l;i++)
		tf3[i]=a[i];
	invf(a,l);
	diff(tf3,l);
	mul(a,tf3,l<<1);
	calc(a,l-1);
    memset(a+l,0,sizeof(LL)*((N<<1)-l-1));
}
int main()
{
	int i,j;
	scanf("%d",&m);
	for(i=0;i<m;i++)
		scanf("%lld",&f1[i]);
	log(f1,m);
	for(i=0;i<m;i++)
		printf("%lld ",f1[i]);
	return 0;
}