第四节 任意模数NTT(MTT+求逆)

有了任意模数下的多项式乘法和多项式求逆，就可以以它们为基础，仿照前面 NTT 一节的做法写出其余各种多项式函数。

#include <bits/stdc++.h>
using namespace std;
// Short type aliases used throughout the file.
using ll = long long;
using pll = pair<ll, ll>;
using pii = pair<int, int>;
using VI = vector<int>;
// Container size as a signed int, so it can be mixed with int indices.
#define SZ(o) (static_cast<int>((o).size()))
// Modulus used by the polynomial routines (1e9 + 7).
const int P = 1000000007;

// Extended Euclidean algorithm. Returns (s, t) with x*s + y*t equal to the
// last nonzero remainder of Euclid's algorithm on (x, y) -- gcd(x, y) for
// non-negative inputs. Iterative form: (r0, r1), (s0, s1), (t0, t1) advance
// in lock-step, and the pair for the final remainder is returned.
pii exgcd(const ll x,const ll y) {
    ll r0 = x, r1 = y;
    ll s0 = 1, s1 = 0;
    ll t0 = 0, t1 = 1;
    while (r1 != 0) {
        const ll q = r0 / r1;
        ll next = r0 % r1;
        r0 = r1; r1 = next;
        next = s0 - q * s1;
        s0 = s1; s1 = next;
        next = t0 - q * t1;
        t0 = t1; t1 = next;
    }
    return pii(s0, t0);
}

namespace FFT {
    const double pi = acos(-1.0);

    // Minimal complex number type for the transforms below (inside this
    // namespace it shadows std::complex).
    struct complex {
        double r, i;
        complex(double x = 0, double y = 0) : r(x), i(y) {}
        complex operator+ (const complex &b) const {
            return complex(r + b.r, i + b.i);
        }
        complex operator- (const complex &b) const {
            return complex(r - b.r, i - b.i);
        }
        complex operator* (const complex &b) const {
            return complex(r * b.r - i * b.i, r * b.i + i * b.r);
        }
        // Complex conjugate.
        complex operator! () const {return complex(r, -i);}
    };

    // Lazily grown tables shared by every transform:
    //   bs  - log2 of the largest transform size prepared so far;
    //   rev - rev[i] is i bit-reversed over bs bits;
    //   rts - rts[k + j] = exp(2*pi*i*j / (2k)) for 0 <= j < k, i.e. the
    //         twiddle factors of a butterfly stage with half-length k.
    int bs = 1;
    VI rev = {0, 1};
    vector<complex> rts = {{0, 0}, {1, 0}};

    // Grows rev/rts so that transforms of size up to 1 << nbs are supported.
    void ensure_base(int nbs) {
        if (nbs <= bs) return;
        rev.resize(1 << nbs);
        for (int i = 0; i < (1 << nbs); i++) {
            rev[i] = (rev[i>>1]>>1) + ((i&1)<<(nbs-1));
        }
        rts.resize(1 << nbs);
        for(; bs < nbs; bs++) {
            double ag = 2 * pi / (1 << (bs + 1));
            for (int i = 1 << (bs - 1); i < (1 << bs); i++) {
                // Even slots reuse the previous level's roots; odd slots are new.
                rts[i << 1] = rts[i];
                double agi = ag * (2 * i + 1 - (1 << bs));
                rts[(i << 1) + 1] = complex(cos(agi), sin(agi));
            }
        }
    }

    // In-place iterative radix-2 forward FFT of the first n entries of a
    // (n defaults to a.size(); it is assumed to be a power of two).
    // There is no separate inverse transform: the multiply routines write
    // their pointwise products at mirrored indices (i <-> sz - i) and then
    // apply this same forward transform again.
    void FFT(vector<complex> &a, int n = -1) {
        if (n == -1) n = SZ(a);
        int zeros = __builtin_ctz(n);
        ensure_base(zeros);
        // rev is built for 1 << bs entries; shifting adapts it to size n.
        int shift = bs - zeros;
        for (int i = 0; i < n; i++) if (i < (rev[i] >> shift)) {
            swap(a[i], a[rev[i] >> shift]);
        }
        for (int k = 1; k < n; k <<= 1) {
            for (int i = 0; i < n; i += 2 * k) {
                for(int j = 0; j < k; j++) {
                    complex z = a[i + j + k] * rts[j + k];
                    a[i + j + k] = a[i + j] - z;
                    a[i + j] = a[i + j] + z;
                }
            }
        }
    }

    // Scratch buffers reused across calls. Together with the tables above this
    // is shared mutable state, so these routines are not thread-safe.
    vector<complex> fa, fb;

    // Plain (non-modular) integer product of a and b, using a single complex
    // FFT with a in the real part and b in the imaginary part. Results are
    // rounded with +0.5, which assumes non-negative coefficients whose exact
    // products fit in int and within double precision.
    // An empty operand yields an empty product.
    VI multiply(const VI &a, const VI &b) {
        // Without this guard nd could be <= 0 and vector<int>(nd) would fail.
        if (a.empty() || b.empty()) return VI();
        int nd = SZ(a) + SZ(b) - 1;
        int nbs = nd > 1 ? 32 - __builtin_clz(nd - 1) : 0;
        ensure_base(nbs);
        int sz = 1 << nbs;
        if (sz > (int) fa.size()) fa.resize(sz);
        for (int i = 0; i < sz; i++) {
            int x = (i < (int) a.size() ? a[i] : 0);
            int y = (i < (int) b.size() ? b[i] : 0);
            fa[i] = complex(x, y);
        }
        FFT(fa, sz);
        // r folds together the 1/(4i) that separates A*B from (A+iB)^2 and
        // the 1/sz normalisation of the second (inverse-acting) transform.
        complex r(0, -0.25 / sz);
        for (int i = 0; i <= (sz >> 1); i++) {
            int j = (sz - i) & (sz - 1);
            complex z = (fa[j] * fa[j] - !(fa[i] * fa[i])) * r;
            if (i != j) {
                fa[j] = (fa[i] * fa[i] - !(fa[j] * fa[j])) * r;
            }
            fa[i] = z;
        }
        FFT(fa, sz);
        vector<int> res(nd);
        for (int i = 0; i < nd; i++) res[i] = fa[i].r + 0.5;
        return res;
    }

    // Product of a and b with every coefficient reduced modulo m.
    // Inputs are first reduced into [0, m) and split as lo + hi * 2^15, so m
    // is presumably meant to be below 2^30 (true for P). Accuracy relies on
    // double precision for the transform sizes used -- not verified here.
    // If eq is nonzero, b is assumed identical to a: a's transform is copied
    // instead of transforming b (b's contents are then not read).
    // An empty operand yields an empty product.
    VI multiply_mod(const VI &a, const VI &b, int m, int eq = 0) {
        // Without this guard nd could be <= 0, and the fill below could be
        // handed a reversed range (a.size() > sz).
        if (a.empty() || b.empty()) return VI();
        int nd = SZ(a) + SZ(b) - 1;
        int nbs = nd > 1 ? 32 - __builtin_clz(nd - 1) : 0;
        ensure_base(nbs);
        int sz = 1 << nbs;
        if (sz > SZ(fa)) fa.resize(sz);
        for (int i = 0; i < SZ(a); i++) {
            int x = (a[i] % m + m) % m;
            fa[i] = complex(x & ((1 << 15) - 1), x >> 15);
        }
        fill(fa.begin() + a.size(), fa.begin() + sz, complex(0, 0));
        FFT(fa, sz);
        if (sz > SZ(fb)) fb.resize(sz);
        if (eq) copy(fa.begin(), fa.begin() + sz, fb.begin());
        else {
            for (int i = 0; i < SZ(b); i++) {
                int x = (b[i] % m + m) % m;
                fb[i] = complex(x & ((1 << 15) - 1), x >> 15);
            }
            fill(fb.begin() + b.size(), fb.begin() + sz, complex(0, 0));
            FFT(fb, sz);
        }
        double rto = 0.25 / sz;
        complex r2(0, -1), r3(rto, 0), r4(0, -rto), r5(0, 1);
        // Separate the lo/hi spectra of both inputs (a1/a2, b1/b2) and
        // recombine them so that, after the second transform, fa.r holds
        // lo*lo, fb.r holds the cross terms and fa.i holds hi*hi.
        // Results are written at mirrored indices (see FFT above).
        for (int i = 0; i <= (sz >> 1); i++) {
            int j = (sz - i) & (sz - 1);
            complex a1 = (fa[i] + !fa[j]),
                    a2 = (fa[i] - !fa[j]) * r2,
                    b1 = (fb[i] + !fb[j]) * r3,
                    b2 = (fb[i] - !fb[j]) * r4;
            if (i != j) {
                complex c1 = (fa[j] + !fa[i]),
                        c2 = (fa[j] - !fa[i]) * r2,
                        d1 = (fb[j] + !fb[i]) * r3,
                        d2 = (fb[j] - !fb[i]) * r4;
                fa[i] = c1 * d1 + c2 * d2 * r5;
                fb[i] = c1 * d2 + c2 * d1;
            }
            fa[j] = a1 * b1 + a2 * b2 * r5;
            fb[j] = a1 * b2 + a2 * b1;
        }
        FFT(fa, sz); FFT(fb, sz);
        vector<int> res(nd);
        for (int i = 0; i < nd; i++) {
            ll aa = fa[i].r + 0.5;
            ll bb = fb[i].r + 0.5, cc = fa[i].i + 0.5;
            res[i] = (aa + ((bb % m) << 15) + ((cc % m) << 30)) % m;
        }
        return res;
    }

    // Power-series inverse of a modulo x^n (n = a.size()) and modulo P:
    // returns b with a * b == 1 (mod x^n, mod P), coefficients in [0, P).
    // Newton iteration b <- 2b - a*b^2, halving n at each recursion level.
    // a[0] must be invertible mod P; if a[0] == 0 (mod P) the result is all
    // zeros rather than an inverse. An empty input yields an empty result.
    VI getInv(const VI &a) {
        // Without this the recursion below would call itself on an empty
        // vector forever.
        if (a.empty()) return VI();
        if (SZ(a) == 1) {
            // Reduce into [0, P) first: for negative arguments exgcd can end at
            // gcd = -1 and would then yield the negated inverse.
            const ll a0 = ((ll)a[0] % P + P) % P;
            const int inv = exgcd(a0, P).first;
            return vector<int>(1, inv < 0 ? inv + P : inv);
        }
        const int n = SZ(a), sb = (n + 1) / 2;
        // Inverse of the low ceil(n/2) coefficients.
        VI b = getInv(VI(a.begin(), a.begin() + sb));
        // c = a * b^2, truncated to n terms; squaring reuses b's transform.
        VI c = multiply_mod(b, b, P, 1);
        c.resize(n);
        c = multiply_mod(a, c, P);
        b.resize(n); c.resize(n);
        for(int i = 0; i < n; i++) c[i] = (2ll * b[i] - c[i] + P) % P;
        return c;
    }
}
namespace Poly {
    VI operator * (VI a, VI b) {
        if (a == b) return FFT::multiply_mod(a, b, P, 1);
        else return FFT::multiply_mod(a, b, P);
    }
    VI operator ~ (VI a) {
        return FFT::getInv(a);
    }
    void print(VI v) {
        int sz = SZ(v);
        for(int i = 0; i < sz - 1; i++) printf("%d ", v[i]);
        printf("%d\n", v[sz-1]);
    }
}
using namespace Poly;

// Reads n followed by the n coefficients a_0 .. a_{n-1}, then prints the first
// n coefficients of the power-series inverse of a modulo P.
// Returns 1 on malformed input (missing numbers or n <= 0).
int main() {
    int n;
    // n <= 0 used to crash: VI a(n) with negative n, or unbounded getInv
    // recursion for n == 0.
    if (scanf("%d", &n) != 1 || n <= 0) return 1;
    VI a(n);
    for (int &x : a) {
        if (scanf("%d", &x) != 1) return 1;
    }
    const VI v = ~a;
    print(v);
    return 0;
}
暂无评论

发送评论 编辑评论


				
|´・ω・)ノ
ヾ(≧∇≦*)ゝ
(☆ω☆)
(╯‵□′)╯︵┴─┴
 ̄﹃ ̄
(/ω\)
∠( ᐛ 」∠)_
(๑•̀ㅁ•́ฅ)
→_→
୧(๑•̀⌄•́๑)૭
٩(ˊᗜˋ*)و
(ノ°ο°)ノ
(´இ皿இ`)
⌇●﹏●⌇
(ฅ´ω`ฅ)
(╯°A°)╯︵○○○
φ( ̄∇ ̄o)
ヾ(´・ ・`。)ノ"
( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
(ó﹏ò。)
Σ(っ °Д °;)っ
( ,,´・ω・)ノ"(´っω・`。)
╮(╯▽╰)╭
o(*////▽////*)q
>﹏<
( ๑´•ω•) "(ㆆᴗㆆ)
😂
😀
😅
😊
🙂
🙃
😌
😍
😘
😜
😝
😏
😒
🙄
😳
😡
😔
😫
😱
😭
💩
👻
🙌
🖕
👍
👫
👬
👭
🌚
🌝
🙈
💊
😶
🙏
🍦
🍉
😣
Source: github.com/k4yt3x/flowerhd
颜文字
Emoji
小恐龙
花!
上一篇
下一篇