第三节 多项式全家桶(NTT)

Contents

常用函数与定义

#include <bits/stdc++.h>
using namespace std;
#define SZ(x) ((int)(x).size())
typedef vector<int> VI;
typedef long long ll;
const int P = 998244353;

void print(VI a) {
    int n = int(a.size());
    for (int i = 0; i < n - 1; i++) printf("%d ", a[i]);
    printf("%d\n", a[n - 1]);
}
inline void add(int &x, int y) {
    x += y;
    if (x >= P) x -= P;
}
inline void sub(int &x, int y) {
    x -= y;
    if (x < 0) x += P;
}
inline int mul(int x, int y) {
    return 1LL * x * y % P;
}
int ksm(int x, int y) {
    int res = 1;
    for (; y; y >>= 1, x = mul(x, x)) {
        if (y & 1) res = mul(res, x);
    }
    return res;
}
inline int inv(int a) {
    a %= P;
    if (a < 0) a += P;
    int b = P, u = 0, v = 1;
    while (a) {
        int t = b / a;
        b -= t * a; swap(a, b);
        u -= t * v; swap(u, v);
    }
    if (u < 0) u += P;
    return u;
}

NTT

前面的函数都是必要的;faq 函数用于多项式快速幂(不对 x^n 取模),inverse 函数用于多项式求逆。

namespace NTT {
    int bs = 1, rt = -1, mbs = -1;
    vector<int> rev = {0, 1}, rts = {0, 1};

    void init() {
        int temp = P - 1; mbs = 0;
        while (temp % 2 == 0) {
            temp >>= 1; ++mbs;
        }
        rt = 2;
        while (true) {
            if (ksm(rt, 1 << mbs) == 1 && ksm(rt, 1 << (mbs - 1)) != 1) {
                break;
            }
            ++rt;
        }
    }
    void ensure_base(int nbase) {
        if (mbs == -1) init();
        if (nbase <= bs) return;
        assert(nbase <= mbs);
        rev.resize(1 << nbase);
        for (int i = 0; i < 1 << nbase; ++i) {
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << (nbase - 1);
        }
        rts.resize(1 << nbase);
        while (bs < nbase) {
            int z = ksm(rt, 1 << (mbs - 1 - bs));
            for (int i = 1 << (bs - 1); i < 1 << bs; ++i) {
                rts[i << 1] = rts[i];
                rts[i << 1 | 1] = mul(rts[i], z);
            }
            ++bs;
        }
    }
    void dft(VI &a) {
        int n = SZ(a), zeros = __builtin_ctz(n);
        ensure_base(zeros);
        int shift = bs - zeros;
        for (int i = 0; i < n; ++i) {
            if (i < rev[i] >> shift) {
                swap(a[i], a[rev[i] >> shift]);
            }
        }
        for (int i = 1; i < n; i <<= 1) {
            for (int j = 0; j < n; j += i << 1) {
                for (int k = 0; k < i; ++k) {
                    int x = a[j + k], y = mul(a[j + k + i], rts[i + k]);
                    a[j + k] = (x + y) % P;
                    a[j + k + i] = (x + P - y) % P;
                }
            }
        }
    }

    VI multiply(VI a, VI b) {
        int need = SZ(a) + SZ(b) - 1, nbase = 0;
        while (1 << nbase < need) ++nbase;
        ensure_base(nbase);
        int sz = 1 << nbase;
        a.resize(sz); b.resize(sz);
        bool equal = (a == b);
        dft(a);
        if (equal) b = a; else dft(b);
        int invsz = inv(sz);
        for (int i = 0; i < sz; ++i) {
            a[i] = mul(mul(a[i], b[i]), invsz);
        }
        reverse(a.begin() + 1, a.end());
        dft(a); a.resize(need);
        return a;
    }

    VI faq(vector<int> a, int m) {
        int need = (SZ(a) - 1) * m + 1, nbase = 0;
        while (1 << nbase < need) ++nbase;
        ensure_base(nbase);
        int sz = 1 << nbase;
        a.resize(sz);
        dft(a);
        int invsz = inv(sz);
        for (int i = 0; i < sz; ++i) {
            a[i] = mul((int)ksm(a[i], m), invsz);
        }
        reverse(a.begin() + 1, a.end());
        dft(a); a.resize(need);
        return a;
    }

    VI inverse(VI a) {
        int n = SZ(a), m = (n + 1) >> 1;
        if (n == 1) {
            return vector<int>(1, inv(a[0]));
        }
        else {
            vector<int> b = inverse(vector<int>(a.begin(), a.begin() + m));
            int need = n << 1, nbase = 0;
            while (1 << nbase < need) ++nbase;
            ensure_base(nbase);
            int sz = 1 << nbase;
            a.resize(sz); b.resize(sz);
            dft(a); dft(b);
            int invsz = inv(sz);
            for (int i = 0; i < sz; ++i) {
                a[i] = mul(mul(P + 2 - mul(a[i], b[i]), b[i]), invsz);
            }
            reverse(a.begin() + 1, a.end());
            dft(a); a.resize(n);
            return a;
        }
    }
}
using NTT::multiply;
using NTT::inverse;

各种函数的封装

VI& operator += (VI &a, const VI &b) {
    if (a.size() < b.size()) {
        a.resize(b.size());
    }
    for (int i = 0; i < SZ(b); ++i) {
        add(a[i], b[i]);
    }
    return a;
}

VI operator + (const VI &a, const VI &b) {
    VI c = a;
    return c += b;
}

VI& operator -= (VI &a, const VI &b) {
    if (a.size() < b.size()) a.resize(b.size());
    for (int i = 0; i < SZ(b); i++) {
        sub(a[i], b[i]);
    }
    return a;
}

VI operator - (const VI &a, const VI &b) {
    VI c = a;
    return c -= b;
}

VI& operator *= (VI &a, const VI &b) {
    if (min(a.size(), b.size()) < 128) {
        vector<int> c = a;
        a.assign(a.size() + b.size() - 1, 0);
        for (int i = 0; i < SZ(c); ++i) {
            for (int j = 0; j < SZ(b); ++j) {
                add(a[i + j], mul(c[i], b[j]));
            }
        }
    }
    else a = multiply(a, b);
    return a;
}

VI operator * (const VI &a, const VI &b) {
    VI c = a;
    return c *= b;
}

VI& operator /= (VI &a, const VI &b) {
    int n = SZ(a), m = SZ(b);
    if (n < m) a.clear();
    else {
        VI c = b;
        reverse(a.begin(), a.end());
        reverse(c.begin(), c.end());
        c.resize(n - m + 1);
        a *= inverse(c);
        a.erase(a.begin() + n - m + 1, a.end());
        reverse(a.begin(), a.end());
    }
    return a;
}

VI operator / (const VI &a, const VI &b) {
    vector<int> c = a;
    return c /= b;
}

VI& operator %= (VI &a, const VI &b) {
    int n = SZ(a), m = SZ(b);
    if (n >= m) {
        VI c = (a / b) * b;
        a.resize(m - 1);
        for (int i = 0; i < m - 1; ++i) {
            sub(a[i], c[i]);
        }
    }
    return a;
}

VI operator % (const VI &a, const VI &b) {
    vector<int> c = a;
    return c %= b;
}

VI operator ^ (const VI &a, const int &m) {
    return NTT::faq(a, m);
}

VI derivative(const VI &a) {
    int n = SZ(a);
    VI b(n - 1);
    for (int i = 1; i < n; ++i) {
        b[i - 1] = mul(a[i], i);
    }
    return b;
}

VI primitive(const VI &a) {
    int n = SZ(a);
    VI b(n + 1), invs(n + 1);
    for (int i = 1; i <= n; ++i) {
        invs[i] = i == 1 ? 1 : mul(P - P / i, invs[P % i]);
        b[i] = mul(a[i - 1], invs[i]);
    }
    return b;
}

VI logarithm(const VI &a) {
    VI b = primitive(derivative(a) * inverse(a));
    b.resize(a.size());
    return b;
}

VI exponent(const VI &a) {
    VI b(1, 1);
    while (b.size() < a.size()) {
        VI c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
        add(c[0], 1);
        VI oldb = b;
        b.resize(b.size() << 1);
        c -= logarithm(b);
        c *= oldb;
        for (int i = SZ(b) >> 1; i < SZ(b); ++i) {
            b[i] = c[i];
        }
    }
    b.resize(a.size());
    return b;
}

VI power(const VI &a, int m) {
    int n = SZ(a), p = -1;
    VI b(n);
    for (int i = 0; i < n; ++i) {
        if (a[i]) {
            p = i; break;
        }
    }
    if (p == -1) {
        b[0] = !m; return b;
    }
    if ((long long) m * p >= n) return b;
    int mu = ksm(a[p], m), di = inv(a[p]);
    VI c(n - m * p);
    for (int i = 0; i < n - m * p; ++i) {
        c[i] = mul(a[i + p], di);
    }
    c = logarithm(c);
    for (int i = 0; i < n - m * p; ++i) {
        c[i] = mul(c[i], m);
    }
    c = exponent(c);
    for (int i = 0; i < n - m * p; ++i) {
        b[i + m * p] = mul(c[i], mu);
    }
    return b;
}

VI sqrt(const VI &a) {
    vector<int> b(1, 1);
    while (b.size() < a.size()) {
        VI c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
        VI oldb = b;
        b.resize(b.size() << 1);
        c *= inverse(b);
        for (int i = SZ(b) >> 1; i < SZ(b); ++i) {
            b[i] = mul(c[i], (P + 1) >> 1);
        }
    }
    b.resize(a.size());
    return b;
}

VI multiply_all(int l, int r, vector<VI> &all) {
    if (l > r) return VI();
    else if (l == r) return all[l];
    else {
        int y = (l + r) >> 1;
        return multiply_all(l, y, all) * multiply_all(y + 1, r, all);
    }
}

VI evaluate(const VI &f, const VI &x) {
    int n = SZ(x);
    if (!n) return VI();
    vector<VI> up(n * 2);
    for (int i = 0; i < n; ++i) {
        up[i + n] = VI{(P - x[i]) % P, 1};
    }
    for (int i = n - 1; i; --i) {
        up[i] = up[i << 1] * up[i << 1 | 1];
    }
    vector<VI> down(n * 2);
    down[1] = f % up[1];
    for (int i = 2; i < n * 2; ++i) {
        down[i] = down[i >> 1] % up[i];
    }
    VI y(n);
    for (int i = 0; i < n; ++i) {
        y[i] = down[i + n][0];
    }
    return y;
}

VI interpolate(const VI &x, const VI &y) {
    int n = SZ(x);
    vector<VI> up(n * 2);
    for (int i = 0; i < n; ++i) {
        up[i + n] = VI{(P - x[i]) % P, 1};
    }
    for (int i = n - 1; i; --i) {
        up[i] = up[i << 1] * up[i << 1 | 1];
    }
    VI a = evaluate(derivative(up[1]), x);
    for (int i = 0; i < n; ++i) {
        a[i] = mul(y[i], inv(a[i]));
    }
    vector<VI> down(n * 2);
    for (int i = 0; i < n; ++i) {
        down[i + n] = VI(1, a[i]);
    }
    for (int i = n - 1; i; --i) {
        down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
    }
    return down[1];
}

函数使用方法

多项式除法与取模:给定一个 n 次多项式 F(x) 和一个 m 次多项式 G(x) ,请求出多项式 Q(x), R(x),满足以下条件:Q(x) 次数为 n-mR(x) 次数小于 m; F(x) = Q(x) * G(x) + R(x)

多项式开根:给定一个n−1次多项式A(x),求一个在\bmod\ x^n意义下的多项式B(x),使得B^2(x) \equiv A(x) \ (\bmod\ x^n)。保证a_0 = 1.

多项式快速幂:给定一个n-1次多项式A(x),求一个在\bmod\ x^n意义下的多项式B(x),使得B(x) \equiv A^k(x) \ (\bmod\ x^n)。保证a_0 = 1.

多项式 exp:给出 n−1 次多项式 A(x),求一个 \bmod{x^n} 下的多项式 B(x),满足 B(x) \equiv \text e^{A(x)}。系数对 998244353 取模。保证 a_0 = 0.

多项式 ln:给出 n-1 次多项式 A(x),求一个 \bmod{x^n}mod 下的多项式 B(x),满足 B(x) \equiv \ln A(x). 在 \text{mod } 998244353 下进行.保证 a_0 = 1.

多项式多点求值:给定一个 n 次多项式 f(x) ,现在请你对于 i \in [1,m] ,求出 f(a_i)。第一行两个正整数 n,m 表示多项式的次数及你要求的点值的数量。第二行 n+1个非负整数,由低到高地给出多项式的系数。第三行 m 个非负整数,表示 a_i

多项式快速插值:给出 n 个点 (x_i, y_i),求一个 x-1 次的多项式 f(x),使得 f(x_i)\equiv y_i\pmod{998244353}.

VI ans, a, b;

int main()
{
    //多项式乘法
//    int n, m; scanf("%d%d", &n, &m);
//    for(int i = 0, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
//    for(int i = 0, j; i <= m; i++) {scanf("%d", &j); b.push_back(j);}
//    ans = a * b;
//    print(ans);
    //多项式求逆
//    int n; scanf("%d", &n);
//    for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
//    ans = inverse(a);
//    print(ans);
    //多项式log
//    int n; scanf("%d", &n);
//    for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
//    ans = logarithm(a);
//    print(ans);
    //多项式快速幂(对x^n取模)
//    auto getmod = [&] (int Mo) mutable -> int
//    {
//        long long b = 0; char c;
//        while(!isdigit(c = getchar()));
//        for(; isdigit(c); c = getchar())
//        {
//            b = b * 10 + c - '0';
//            if(b >= mod) b %= mod;
//        }
//        return int(b);
//    };
//    int n, m; scanf("%d", &n); m = getmod(998244353);
//    a.resize(n);
//    for(int i = 0; i < n; i++) scanf("%d", &a[i]);
//    ans = power(a, m);
//    print(ans);
    //多项式快速幂(不取模)
//    ans = a ^ m;
    //多项式除法和取模
//    int n, m; scanf("%d%d", &n, &m);
//    for(int i = 0, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
//    for(int i = 0, j; i <= m; i++) {scanf("%d", &j); b.push_back(j);}
//    VI ans1 = a / b, ans2 = a % b;
//    print(ans1);
//    print(ans2);
    //多项式开根号
//    int n; scanf("%d", &n);
//    for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
//    ans = sqrt(a);
//    print(ans);
    //多项式对数函数
//    int n; scanf("%d", &n);
//    for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
//    ans = logarithm(a);
//    print(ans);
    //多项式指数函数
//    int n; scanf("%d", &n);
//    for(int i = 1, j; i <= n; i++) {scanf("%d", &j); a.push_back(j);}
//    ans = exponent(a);
//    print(ans);
    //多项式多点求值
//    int n, m; scanf("%d%d", &n, &m);
//    a.resize(n+1); b.resize(m);
//    for(int i = 0; i <= n; i++) scanf("%d", &a[i]);
//    for(int i = 0; i < m; i++) scanf("%d", &b[i]);
//    ans = evaluate(a, b);
//    for(int i = 0; i < m; i++) printf("%d\n", ans[i]);
    //多项式快速插值
//    int n;
//    scanf("%d", &n);
//    a.resize(n); b.resize(n);
//    for(int i = 0; i < n; i++) scanf("%d%d", &a[i], &b[i]);
//    ans = interpolate(a, b);
//    print(ans);
    return 0;
}
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇