Description
在2016年,佳媛姐姐刚刚学习了第二类斯特林数,非常开心。
现在他想计算这样一个函数的值:
表示第二类斯特林数,递推公式为:
边界条件为
你能帮帮他吗?
Input
输入只有一个正整数n。
Output
输出f(n)。由于结果会很大,输出f(n)对998244353(7×17×223+1)取模的结果即可。
Solution
斯特林数表示将
个不同的元素划分成
个不同的集合,且每个集合非空的方案数。
设:
设,即集合是不同的。
于是可以推出:
然后可以理解为每个集合都有两种颜色,枚举时顺便枚举颜色即可。
设:
显然这是个卷积的形式,但是等式两边都有,所以要使用CDQ分治。
就像普通的CDQ一样,每次都用去更新
。
仔细看看就能发现可以把和
作卷积去更新
。
Code
#include<set> #include<map> #include<vector> #include<string> #include<math.h> #include<time.h> #include<stdio.h> #include<stdlib.h> #include<string.h> #include<iostream> #include<algorithm> #define maxn 600000 #define ll long long #define mod 998244353 #define lowbit(x) (x&-x) #define maxint 2147483647 #define abs(x) (a<0?-a:a) #define max(a,b) (a>b?a:b) #define min(a,b) (a<b?a:b) #define ckmax(a,b) (a=max(a,b)) #define ckmin(a,b) (a=min(a,b)) #define isd(c) ('0'<=c&&c<='9') using namespace std; template<typename T>inline void read(T&x){T f=1;x=0;char c; for (c=getchar(); !isd(c); c=getchar()) f=(c=='-')?-1:f; for (; isd(c); c=getchar()) x=(x<<3)+(x<<1)+(c^48);x*=f; } template<typename T>inline void wt(T x,char c='\0'){char ch[(sizeof(T)+1)<<1];if (x<0) putchar('-'),x=-x; int t=-1; do ch[++t]=x%10+'0',x/=10; while(x); do putchar(ch[t]); while(t--); if (c!='\0') putchar(c); } int n,rev[maxn]; ll ans,a[maxn],b[maxn],f[maxn],inv[maxn],fact[maxn],finv[maxn]; inline ll qp(ll a,ll b) {ll ans=1;for (; b; b>>=1,a=a*a%mod)if (b&1) ans=ans*a%mod;return ans;} inline void ntt(ll*c,int n,int type) { for (int i=0; i<n; i++) if (i<rev[i]) swap(c[i],c[rev[i]]); for (int i=1; i<n; i<<=1) { ll x=qp(3,type==1?(mod-1)/(i<<1):mod-1-(mod-1)/(i<<1)); for (int j=0; j<n; j+=(i<<1)) {ll y=1; for (int k=0; k<i; k++,y=y*x%mod) { ll p=c[j+k],q=y*c[i+j+k]; c[j+k]=(p+q)%mod,c[i+j+k]=(p-q+mod)%mod; } } } } inline void cdq(int l,int r) { if (l==r) return;int mid=(l+r)>>1;cdq(l,mid); int lim=r-l+1,n=1,k=0;for (; n<lim; n<<=1) k++; for (int i=0; i<n; i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1)); for (int i=0; i<n; i++) a[i]=b[i]=0; for (int i=l; i<=mid; i++) a[i-l]=f[i]; for (int i=0; i<=r-l; i++) b[i]=finv[i]; ntt(a,n,1),ntt(b,n,1); for (int i=0; i<n; i++) a[i]=a[i]*b[i]%mod; ntt(a,n,-1);ll tmp=qp(n,mod-2); for (int i=0; i<n; i++) a[i]=a[i]*tmp%mod; for (int i=mid+1; i<=r; i++) f[i]=(f[i]+2*a[i-l])%mod; cdq(mid+1,r); } int main() { read(n); f[0]=inv[1]=fact[0]=finv[0]=1; for (int i=1; i<=n; i++) { if (i!=1) inv[i]=(mod-mod/i)*inv[mod%i]%mod; fact[i]=fact[i-1]*i%mod,finv[i]=finv[i-1]*inv[i]%mod; } cdq(0,n); for (int i=0; i<=n; i++) ans=(ans+f[i]*fact[i]%mod)%mod; wt((ans+mod)%mod,'\n');return 0; }