Warning: session_start(): open(/tmp/sess_dbfbdd8bff5b68544e35a5a8f7ce62e4, O_RDWR) failed: No space left on device (28) in /data/wiki/inc/init.php on line 239
Warning: session_start(): Failed to read session data: files (path: ) in /data/wiki/inc/init.php on line 239
Warning: Cannot modify header information - headers already sent by (output started at /data/wiki/inc/init.php:239) in /data/wiki/inc/auth.php on line 430
Writing /data/wiki/data/cache/8/8fe637ac40e9dfa91f93fbcd08d065e4.captchaip failed
Warning: Cannot modify header information - headers already sent by (output started at /data/wiki/inc/init.php:239) in /data/wiki/inc/Action/Export.php on line 103
Warning: Cannot modify header information - headers already sent by (output started at /data/wiki/inc/init.php:239) in /data/wiki/inc/Action/Export.php on line 103
Warning: Cannot modify header information - headers already sent by (output started at /data/wiki/inc/init.php:239) in /data/wiki/inc/Action/Export.php on line 103
====== 子集卷积 ======
==== 简介 ====
一般我们有如下一类的状压dp方程,如$dp[i]=\sum dp[j]*w[k]$ ($i,j,k$满足$j\lor k=i,j \land k=0$,这里符号表示按位与和按位或。
如果暴力枚举位的子集,那么效率是$3^n$的,难以承受。
实际上这个已经很接近一个FWT卷积的形式了,只不过是还要$j\land k=0$罢了。
我们改变这个条件为,$j$中1的个数+$k$中1的个数=$i$中1的个数,那么当我们为$dp$增加一个“1的个数”的维度时,问题迎刃而解: $$
dp[cnt_i][i]=\sum_{(j|k)==i}dp[cnt_j][j]*w[cnt_i-cnt_j][k]
$$ 注意这里$cnt_i$表示1的个数,或者说子集中的物品数目。这里$cnt_i$和$i$的二进制1的个数如果不等,这个$dp$或者$w$值会置为0。此时只要我们从小到大枚举$cnt$来做FWT就可以得到答案了,实际操作过程中,所有的$dp$都是点值形式,因此得到新的$dp[cnt_i]$只需要做$cnt_i$次对位乘;最后,再将所有的$dp$逆FWT变换回原值。
虽然牺牲了一定空间,但是时间被优化到了$n$次FWT+$n^2$次对位乘法的复杂度:$O((2^n*n)*n+n^2*2^n)=O(n^2*2^n)$。
==== 例题 ====
模板题:https:%%//%%ac.nowcoder.com/acm/contest/5157/D
很容易从题目的形式看出来实际上就是对四个数列求三重卷积,第一重是$i|j$的子集卷积,第二重$(i|j)+k$的FFT/NTT,第三重是$((i|j)+k)\otimes h$的FWT的异或卷积。
#include
#define N 262144
using namespace std;
const int mod = 998244353, inv2 = 499122177;
int n;
int rev[N], lim, hib;
int A[N], B[N], C[N], D[N], popc[N];
int f[20][N], g[20][N], h[20][N];
inline int Add(int u, int v) { return (u += v) >= mod ? u - mod : u; }
inline void Inc(int &u, int v) { if ((u += v) >= mod) u -= mod; }
inline int fpm(int x, int y) {
int r = 1;
while (y) {
if (y & 1) r = 1LL * x * r % mod;
x = 1LL * x * x % mod, y >>= 1;
}
return r;
}
inline int read() {
int x = 0;
char ch = getchar();
while (!isdigit(ch)) ch = getchar();
while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
return x;
}
void getrev(int len) {
lim = 1, hib = -1;
while (lim < len) lim <<= 1, ++hib;
for (int i = 0; i < lim; ++i)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << hib);
}
void fwtOr(int *a, bool type) {
for (int mid = 1; mid < lim; mid <<= 1)
for (int i = 0; i < lim; i += (mid << 1))
for (int j = 0; j < mid; ++j)
if (type) Inc(a[i + mid + j], a[i + j]);
else Inc(a[i + mid + j], mod - a[i + j]);
}
void fwtXor(int *a, bool type) {
static int x, y;
for (int mid = 1; mid < n; mid <<= 1)
for (int len = mid << 1, i = 0; i < n; i += len)
for (int j = 0; j < mid; ++j) {
x = a[i + j], y = a[i + mid + j];
a[i + j] = Add(x, y), a[i + mid + j] = Add(x, mod - y);
if (!type)
a[i + j] = 1LL * inv2 * a[i + j] % mod,
a[i + mid + j] = 1LL * inv2 * a[i + mid + j] % mod;
}
}
void NTT(int *a, bool type) {
for (int i = 0; i < lim; ++i)
if (i < rev[i])
swap(a[i], a[rev[i]]);
static int x, y;
for (int mid = 1; mid < lim; mid <<= 1) {
int len = mid << 1, wn = fpm(3, (mod - 1) / len);
for (int i = 0; i < lim; i += len)
for (int j = 0, w = 1; j < mid; ++j, w = 1LL * w * wn % mod) {
x = a[i + j], y = 1LL * w * a[i + mid + j] % mod;
a[i + j] = Add(x, y), a[i + mid + j] = Add(x, mod - y);
}
}
if (!type) {
reverse(a + 1, a + lim);
int inv = fpm(lim, mod - 2);
for (int i = 0; i < lim; ++i)
a[i] = 1LL * inv * a[i] % mod;
}
}
int main() {
n = read(), ++n;
getrev(n + n - 1);
for (int i = 0; i < lim; ++i) popc[i] = popc[i >> 1] + (i & 1);
for (int i = 0; i < n; ++i) A[i] = read(), f[popc[i]][i] = A[i];
for (int i = 0; i < n; ++i) B[i] = read(), g[popc[i]][i] = B[i];
for (int i = 0; i < n; ++i) C[i] = read();
for (int i = 0; i < n; ++i) D[i] = read();
for (int i = 0; i <= 17; ++i)
fwtOr(f[i], true), fwtOr(g[i], true);
for (int sa = 0; sa <= 17; ++sa)
for (int sb = 0; sb + sa <= 17; ++sb)
for (int i = 0; i < lim; ++i)
h[sa + sb][i] = (h[sa + sb][i] + 1LL * f[sa][i] * g[sb][i]) % mod;
for (int i = 0; i <= 17; ++i)
fwtOr(h[i], false);
for (int i = 0; i < lim; ++i)
A[i] = h[popc[i]][i];
NTT(A, true), NTT(C, true);
for (int i = 0; i < lim; ++i)
A[i] = 1LL * A[i] * C[i] % mod;
NTT(A, false);
fwtXor(A, true), fwtXor(D, true);
for (int i = 0; i < lim; ++i)
A[i] = 1LL * A[i] * D[i] % mod;
fwtXor(A, false);
int Q = read();
while (Q--) printf("%d\n", A[read()]);
return 0;
}