Contents
常用函数与定义
#include <bits/stdc++.h>
using namespace std;
#define SZ(x) ((int)(x).size())
typedef vector<int> VI;
typedef long long ll;
const int P = 998244353;
void print(VI a) {
int n = int(a.size());
for (int i = 0; i < n - 1; i++) printf("%d ", a[i]);
printf("%d\n", a[n - 1]);
}
inline void add(int &x, int y) {
x += y;
if (x >= P) x -= P;
}
inline void sub(int &x, int y) {
x -= y;
if (x < 0) x += P;
}
inline int mul(int x, int y) {
return 1LL * x * y % P;
}
int ksm(int x, int y) {
int res = 1;
for (; y; y >>= 1, x = mul(x, x)) {
if (y & 1) res = mul(res, x);
}
return res;
}
inline int inv(int a) {
a %= P;
if (a < 0) a += P;
int b = P, u = 0, v = 1;
while (a) {
int t = b / a;
b -= t * a; swap(a, b);
u -= t * v; swap(u, v);
}
if (u < 0) u += P;
return u;
}
NTT
前面的函数都是必要的;faq
函数用于多项式快速幂(不对 x^n 取模),inverse
函数用于多项式求逆。
namespace NTT {
int bs = 1, rt = -1, mbs = -1;
vector<int> rev = {0, 1}, rts = {0, 1};
void init() {
int temp = P - 1; mbs = 0;
while (temp % 2 == 0) {
temp >>= 1; ++mbs;
}
rt = 2;
while (true) {
if (ksm(rt, 1 << mbs) == 1 && ksm(rt, 1 << (mbs - 1)) != 1) {
break;
}
++rt;
}
}
void ensure_base(int nbase) {
if (mbs == -1) init();
if (nbase <= bs) return;
assert(nbase <= mbs);
rev.resize(1 << nbase);
for (int i = 0; i < 1 << nbase; ++i) {
rev[i] = rev[i >> 1] >> 1 | (i & 1) << (nbase - 1);
}
rts.resize(1 << nbase);
while (bs < nbase) {
int z = ksm(rt, 1 << (mbs - 1 - bs));
for (int i = 1 << (bs - 1); i < 1 << bs; ++i) {
rts[i << 1] = rts[i];
rts[i << 1 | 1] = mul(rts[i], z);
}
++bs;
}
}
void dft(VI &a) {
int n = SZ(a), zeros = __builtin_ctz(n);
ensure_base(zeros);
int shift = bs - zeros;
for (int i = 0; i < n; ++i) {
if (i < rev[i] >> shift) {
swap(a[i], a[rev[i] >> shift]);
}
}
for (int i = 1; i < n; i <<= 1) {
for (int j = 0; j < n; j += i << 1) {
for (int k = 0; k < i; ++k) {
int x = a[j + k], y = mul(a[j + k + i], rts[i + k]);
a[j + k] = (x + y) % P;
a[j + k + i] = (x + P - y) % P;
}
}
}
}
VI multiply(VI a, VI b) {
int need = SZ(a) + SZ(b) - 1, nbase = 0;
while (1 << nbase < need) ++nbase;
ensure_base(nbase);
int sz = 1 << nbase;
a.resize(sz); b.resize(sz);
bool equal = (a == b);
dft(a);
if (equal) b = a; else dft(b);
int invsz = inv(sz);
for (int i = 0; i < sz; ++i) {
a[i] = mul(mul(a[i], b[i]), invsz);
}
reverse(a.begin() + 1, a.end());
dft(a); a.resize(need);
return a;
}
VI faq(vector<int> a, int m) {
int need = (SZ(a) - 1) * m + 1, nbase = 0;
while (1 << nbase < need) ++nbase;
ensure_base(nbase);
int sz = 1 << nbase;
a.resize(sz);
dft(a);
int invsz = inv(sz);
for (int i = 0; i < sz; ++i) {
a[i] = mul((int)ksm(a[i], m), invsz);
}
reverse(a.begin() + 1, a.end());
dft(a); a.resize(need);
return a;
}
VI inverse(VI a) {
int n = SZ(a), m = (n + 1) >> 1;
if (n == 1) {
return vector<int>(1, inv(a[0]));
}
else {
vector<int> b = inverse(vector<int>(a.begin(), a.begin() + m));
int need = n << 1, nbase = 0;
while (1 << nbase < need) ++nbase;
ensure_base(nbase);
int sz = 1 << nbase;
a.resize(sz); b.resize(sz);
dft(a); dft(b);
int invsz = inv(sz);
for (int i = 0; i < sz; ++i) {
a[i] = mul(mul(P + 2 - mul(a[i], b[i]), b[i]), invsz);
}
reverse(a.begin() + 1, a.end());
dft(a); a.resize(n);
return a;
}
}
}
using NTT::multiply;
using NTT::inverse;
各种函数的封装
VI& operator += (VI &a, const VI &b) {
if (a.size() < b.size()) {
a.resize(b.size());
}
for (int i = 0; i < SZ(b); ++i) {
add(a[i], b[i]);
}
return a;
}
VI operator + (const VI &a, const VI &b) {
VI c = a;
return c += b;
}
VI& operator -= (VI &a, const VI &b) {
if (a.size() < b.size()) a.resize(b.size());
for (int i = 0; i < SZ(b); i++) {
sub(a[i], b[i]);
}
return a;
}
VI operator - (const VI &a, const VI &b) {
VI c = a;
return c -= b;
}
VI& operator *= (VI &a, const VI &b) {
if (min(a.size(), b.size()) < 128) {
vector<int> c = a;
a.assign(a.size() + b.size() - 1, 0);
for (int i = 0; i < SZ(c); ++i) {
for (int j = 0; j < SZ(b); ++j) {
add(a[i + j], mul(c[i], b[j]));
}
}
}
else a = multiply(a, b);
return a;
}
VI operator * (const VI &a, const VI &b) {
VI c = a;
return c *= b;
}
VI& operator /= (VI &a, const VI &b) {
int n = SZ(a), m = SZ(b);
if (n < m) a.clear();
else {
VI c = b;
reverse(a.begin(), a.end());
reverse(c.begin(), c.end());
c.resize(n - m + 1);
a *= inverse(c);
a.erase(a.begin() + n - m + 1, a.end());
reverse(a.begin(), a.end());
}
return a;
}
VI operator / (const VI &a, const VI &b) {
vector<int> c = a;
return c /= b;
}
VI& operator %= (VI &a, const VI &b) {
int n = SZ(a), m = SZ(b);
if (n >= m) {
VI c = (a / b) * b;
a.resize(m - 1);
for (int i = 0; i < m - 1; ++i) {
sub(a[i], c[i]);
}
}
return a;
}
VI operator % (const VI &a, const VI &b) {
vector<int> c = a;
return c %= b;
}
VI operator ^ (const VI &a, const int &m) {
return NTT::faq(a, m);
}
VI derivative(const VI &a) {
int n = SZ(a);
VI b(n - 1);
for (int i = 1; i < n; ++i) {
b[i - 1] = mul(a[i], i);
}
return b;
}
VI primitive(const VI &a) {
int n = SZ(a);
VI b(n + 1), invs(n + 1);
for (int i = 1; i <= n; ++i) {
invs[i] = i == 1 ? 1 : mul(P - P / i, invs[P % i]);
b[i] = mul(a[i - 1], invs[i]);
}
return b;
}
VI logarithm(const VI &a) {
VI b = primitive(derivative(a) * inverse(a));
b.resize(a.size());
return b;
}
VI exponent(const VI &a) {
VI b(1, 1);
while (b.size() < a.size()) {
VI c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
add(c[0], 1);
VI oldb = b;
b.resize(b.size() << 1);
c -= logarithm(b);
c *= oldb;
for (int i = SZ(b) >> 1; i < SZ(b); ++i) {
b[i] = c[i];
}
}
b.resize(a.size());
return b;
}
VI power(const VI &a, int m) {
int n = SZ(a), p = -1;
VI b(n);
for (int i = 0; i < n; ++i) {
if (a[i]) {
p = i; break;
}
}
if (p == -1) {
b[0] = !m; return b;
}
if ((long long) m * p >= n) return b;
int mu = ksm(a[p], m), di = inv(a[p]);
VI c(n - m * p);
for (int i = 0; i < n - m * p; ++i) {
c[i] = mul(a[i + p], di);
}
c = logarithm(c);
for (int i = 0; i < n - m * p; ++i) {
c[i] = mul(c[i], m);
}
c = exponent(c);
for (int i = 0; i < n - m * p; ++i) {
b[i + m * p] = mul(c[i], mu);
}
return b;
}
VI sqrt(const VI &a) {
vector<int> b(1, 1);
while (b.size() < a.size()) {
VI c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
VI oldb = b;
b.resize(b.size() << 1);
c *= inverse(b);
for (int i = SZ(b) >> 1; i < SZ(b); ++i) {
b[i] = mul(c[i], (P + 1) >> 1);
}
}
b.resize(a.size());
return b;
}
VI multiply_all(int l, int r, vector<VI> &all) {
if (l > r) return VI();
else if (l == r) return all[l];
else {
int y = (l + r) >> 1;
return multiply_all(l, y, all) * multiply_all(y + 1, r, all);
}
}
VI evaluate(const VI &f, const VI &x) {
int n = SZ(x);
if (!n) return VI();
vector<VI> up(n * 2);
for (int i = 0; i < n; ++i) {
up[i + n] = VI{(P - x[i]) % P, 1};
}
for (int i = n - 1; i; --i) {
up[i] = up[i << 1] * up[i << 1 | 1];
}
vector<VI> down(n * 2);
down[1] = f % up[1];
for (int i = 2; i < n * 2; ++i) {
down[i] = down[i >> 1] % up[i];
}
VI y(n);
for (int i = 0; i < n; ++i) {
y[i] = down[i + n][0];
}
return y;
}
VI interpolate(const VI &x, const VI &y) {
int n = SZ(x);
vector<VI> up(n * 2);
for (int i = 0; i < n; ++i) {
up[i + n] = VI{(P - x[i]) % P, 1};
}
for (int i = n - 1; i; --i) {
up[i] = up[i << 1] * up[i << 1 | 1];
}
VI a = evaluate(derivative(up[1]), x);
for (int i = 0; i < n; ++i) {
a[i] = mul(y[i], inv(a[i]));
}
vector<VI> down(n * 2);
for (int i = 0; i < n; ++i) {
down[i + n] = VI(1, a[i]);
}
for (int i = n - 1; i; --i) {
down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
}
return down[1];
}
函数使用方法
多项式除法与取模:给定一个 n 次多项式 F(x) 和一个 m 次多项式 G(x) ,请求出多项式 Q(x), R(x),满足以下条件:Q(x) 次数为 n-m,R(x) 次数小于 m; F(x) = Q(x) * G(x) + R(x)
多项式开根:给定一个n−1次多项式A(x),求一个在\bmod\ x^n意义下的多项式B(x),使得B^2(x) \equiv A(x) \ (\bmod\ x^n)。保证a_0 = 1.
多项式快速幂:给定一个n-1次多项式A(x),求一个在\bmod\ x^n意义下的多项式B(x),使得B(x) \equiv A^k(x) \ (\bmod\ x^n)。保证a_0 = 1.
多项式 exp:给出 n−1 次多项式 A(x),求一个 \bmod{x^n} 下的多项式 B(x),满足 B(x) \equiv \text e^{A(x)}。系数对 998244353 取模。保证 a_0 = 0.
多项式 ln:给出 n-1 次多项式 A(x),求一个 \bmod{x^n}mod 下的多项式 B(x),满足 B(x) \equiv \ln A(x). 在 \text{mod } 998244353 下进行.保证 a_0 = 1.
多项式多点求值:给定一个 n 次多项式 f(x) ,现在请你对于 i \in [1,m] ,求出 f(a_i)。第一行两个正整数 n,m 表示多项式的次数及你要求的点值的数量。第二行 n+1个非负整数,由低到高地给出多项式的系数。第三行 m 个非负整数,表示 a_i
多项式快速插值:给出 n 个点 (x_i, y_i),求一个 x-1 次的多项式 f(x),使得 f(x_i)\equiv y_i\pmod{998244353}.
VI ans, a, b;
int main()
{
//多项式乘法
// int n, m; scanf("%d%d", &n, &m);
// for(int i = 0, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
// for(int i = 0, j; i <= m; i++) {scanf("%d", &j); b.push_back(j);}
// ans = a * b;
// print(ans);
//多项式求逆
// int n; scanf("%d", &n);
// for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
// ans = inverse(a);
// print(ans);
//多项式log
// int n; scanf("%d", &n);
// for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
// ans = logarithm(a);
// print(ans);
//多项式快速幂(对x^n取模)
// auto getmod = [&] (int Mo) mutable -> int
// {
// long long b = 0; char c;
// while(!isdigit(c = getchar()));
// for(; isdigit(c); c = getchar())
// {
// b = b * 10 + c - '0';
// if(b >= mod) b %= mod;
// }
// return int(b);
// };
// int n, m; scanf("%d", &n); m = getmod(998244353);
// a.resize(n);
// for(int i = 0; i < n; i++) scanf("%d", &a[i]);
// ans = power(a, m);
// print(ans);
//多项式快速幂(不取模)
// ans = a ^ m;
//多项式除法和取模
// int n, m; scanf("%d%d", &n, &m);
// for(int i = 0, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
// for(int i = 0, j; i <= m; i++) {scanf("%d", &j); b.push_back(j);}
// VI ans1 = a / b, ans2 = a % b;
// print(ans1);
// print(ans2);
//多项式开根号
// int n; scanf("%d", &n);
// for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
// ans = sqrt(a);
// print(ans);
//多项式对数函数
// int n; scanf("%d", &n);
// for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
// ans = logarithm(a);
// print(ans);
//多项式指数函数
// int n; scanf("%d", &n);
// for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
// ans = exponent(a);
// print(ans);
//多项式多点求值
// int n, m; scanf("%d%d", &n, &m);
// a.resize(n+1); b.resize(m);
// for(int i = 0; i <= n; i++) scanf("%d", &a[i]);
// for(int i = 0; i < m; i++) scanf("%d", &b[i]);
// ans = evaluate(a, b);
// for(int i = 0; i < m; i++) printf("%d\n", ans[i]);
//多项式快速插值
// int n;
// scanf("%d", &n);
// a.resize(n); b.resize(n);
// for(int i = 0; i < n; i++) scanf("%d%d", &a[i], &b[i]);
// ans = interpolate(a, b);
// print(ans);
return 0;
}