普通多项式乘法
#include <bits/stdc++.h>
using namespace std;
#define SZ(o) (int)o.size()
namespace FFT {
// Floating-point FFT for polynomial multiplication (iterative radix-2).
using db = double;
using vdb = std::vector<double>;
// Capacity of the bit-reversal table r[]. The padded transform length lim must
// not exceed N; with N = 3e6 + 10 the largest usable lim is 2^21, i.e. a
// product of at most 2^21 coefficients.
const int N = 3e6 + 10;
const db pi = std::acos(-1.0);
// lim: padded transform length (a power of two); l = log2(lim);
// r[i]: i with its low l bits reversed. All three are set by mul().
int lim, l, r[N];

// Minimal complex number type used by the transform.
struct Complex {
    db r, i;
    Complex(db r = 0, db i = 0) : r(r), i(i) {}
    Complex operator + (const Complex &x) const {
        return Complex(r + x.r, i + x.i);
    }
    Complex operator - (const Complex &x) const {
        return Complex(r - x.r, i - x.i);
    }
    Complex operator * (const Complex &x) const {
        return Complex(r * x.r - i * x.i, i * x.r + r * x.i);
    }
};

// In-place transform of A[0, lim) using the table r[] prepared by mul().
// type = 1 is the forward DFT; type = -1 is the inverse DFT *without* the
// 1/lim normalisation (callers divide by lim themselves).
// Requires A.size() >= lim.
void DFT(std::vector<Complex> &A, int type) {
    for (int i = 0; i < lim; i++) if (i < r[i]) {
        std::swap(A[i], A[r[i]]);
    }
    std::vector<Complex> roots;
    for (int mid = 1; mid < lim; mid <<= 1) {
        // Compute each twiddle factor e^{type * i*pi*k/mid} directly instead of
        // by repeated multiplication, which accumulates rounding error as k grows.
        roots.resize(mid);
        for (int k = 0; k < mid; k++) {
            roots[k] = Complex(std::cos(pi * k / mid), type * std::sin(pi * k / mid));
        }
        for (int j = 0; j < lim; j += mid << 1) {
            for (int k = 0; k < mid; k++) {
                Complex x = A[j + k], y = roots[k] * A[j + mid + k];
                A[j + k] = x + y;
                A[j + mid + k] = x - y;
            }
        }
    }
}

// Work buffers for the transforms of the two operands.
std::vector<Complex> A, B;

// Multiplies two polynomials given by coefficient vectors (a[i] is the
// coefficient of x^i) and returns the a.size() + b.size() - 1 coefficients
// of the product. Returns an empty vector if either input is empty.
// The result is floating point; round it when the inputs are integers.
// This is the plain three-transform scheme (DFT a, DFT b, pointwise product,
// inverse DFT). It overwrites the globals lim, l, r, A and B.
vdb mul(const vdb &a, const vdb &b) {
    const int n = int(a.size()), m = int(b.size());
    if (n == 0 || m == 0) return vdb();  // product with an empty polynomial
    const int len = n + m - 1;           // number of product coefficients
    for (lim = 1, l = 0; lim < len; lim <<= 1, ++l);
    assert(lim <= N && "product too long for the r[] table");
    // Bit-reversal permutation. r[0] is set explicitly so that no shift by
    // l - 1 = -1 (undefined behaviour) happens when lim == 1.
    r[0] = 0;
    for (int i = 1; i < lim; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }
    A.assign(lim, Complex());
    B.assign(lim, Complex());
    for (int i = 0; i < n; i++) A[i].r = a[i];
    for (int i = 0; i < m; i++) B[i].r = b[i];
    DFT(A, 1);
    DFT(B, 1);
    for (int i = 0; i < lim; i++) A[i] = A[i] * B[i];  // pointwise product
    DFT(A, -1);
    vdb ret(len);
    for (int i = 0; i < len; i++) ret[i] = A[i].r / lim;
    return ret;
}
}
// Bring the FFT helpers into global scope: main() below calls mul() unqualified,
// and the alternative mul() snippet at the end of this file uses vdb, lim, l,
// r, A and DFT unqualified as well.
using namespace FFT;
// Prints the coefficients in `a` on one line, separated by single spaces and
// terminated by '\n', each rounded to the nearest integer. Prints nothing when
// `a` is empty.
//
// FFT output carries small floating-point errors (e.g. 2.9999999 or
// -3.0000001). std::llround rounds half away from zero, so negative values are
// handled correctly; the previous int(x + 0.5) truncated toward zero and
// printed -2 for -3.0 (since int(-2.5) == -2).
//
// a:   coefficients to print.
// out: destination stream; defaults to stdout so existing calls are unchanged.
void print(const std::vector<double> &a, FILE *out = stdout) {
    const int n = int(a.size());
    for (int i = 0; i < n; i++) {
        std::fprintf(out, "%lld%c", std::llround(a[i]), i == n - 1 ? '\n' : ' ');
    }
}
// Input polynomials and their product, kept at file scope.
vector<double> a, b, ans;
// Reads degrees n and m, then the n + 1 coefficients of the first polynomial
// and the m + 1 coefficients of the second (coefficient of x^0 first), and
// prints the n + m + 1 coefficients of their product on one line.
// Returns 1 on malformed input (missing numbers or a negative degree) instead
// of running on uninitialised or invalid sizes.
int main() {
    int n = 0, m = 0;
    if (scanf("%d%d", &n, &m) != 2 || n < 0 || m < 0) return 1;
    a.resize(n + 1);
    b.resize(m + 1);
    for (int i = 0; i <= n; i++) {
        if (scanf("%lf", &a[i]) != 1) return 1;
    }
    for (int i = 0; i <= m; i++) {
        if (scanf("%lf", &b[i]) != 1) return 1;
    }
    ans = mul(a, b);
    print(ans);
    return 0;
}
三次变两次
用下面的 mul 函数替代上面 namespace FFT 中的 mul。这种写法把 a 放在实部、b 放在虚部,只需一次正变换和一次逆变换(普通写法需要三次 DFT)。代价是精度损失较大:最好只用于整数系数的 FFT(结果四舍五入取整),否则容易出现明显的精度误差。
// "Three transforms become two": meant to replace the mul() inside namespace FFT.
// Packs a into the real part and b into the imaginary part of A. Since
// (a + ib)^2 = a^2 - b^2 + 2i*ab, after one forward and one (unnormalised)
// inverse DFT the imaginary part holds 2 * lim * (a * b). Hence the division
// by lim and by 2.
// Only two DFTs are needed instead of three, at the cost of precision: prefer
// integer coefficients and round the result.
// Returns SZ(a) + SZ(b) - 1 coefficients, or an empty vector if either input
// is empty.
vdb mul(const vdb &a, const vdb &b) {
    int n = SZ(a), m = SZ(b);
    if (n == 0 || m == 0) return vdb();  // product with an empty polynomial
    int len = n + m - 1;                 // number of product coefficients
    for (lim = 1, l = 0; lim < len; lim <<= 1, ++l);
    // Bit-reversal permutation. r[0] is set explicitly so that no shift by
    // l - 1 = -1 (undefined behaviour) happens when lim == 1.
    r[0] = 0;
    for (int i = 1; i < lim; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }
    A.resize(lim);
    for (int i = 0; i < lim; i++) {
        A[i].r = i < n ? a[i] : 0;
        A[i].i = i < m ? b[i] : 0;
    }
    DFT(A, 1);
    for (int i = 0; i < lim; i++) A[i] = A[i] * A[i];
    DFT(A, -1);
    vdb ret(len);
    for (int i = 0; i < len; i++) {
        ret[i] = A[i].i / lim / 2;
    }
    return ret;
}