问题描述
给 m 个区间 [l, r] 以及一个数 x,每次在区间 [l, r] 上的每一个点放置一个物品 x,问:
- $[1, n]$ 中哪个点的物品种类最多?
- $[1, n]$ 中每个点数量最多的物品是哪种?
- ……
解题思路:
首先利用差分,将序列操作变成区间操作:在 l 上放置一个二元组标记 <x, 1>,在 r 上放置一个二元组标记 <x, -1>。
然后建立权值线段树,维护所有的 x。我们从 1 至 n 开始遍历,每次把当前位置的标记全部加入线段树中。由于差分的查询对应前缀和,当前线段树根结点所维护的答案即为结果。
例题:
题目大意:
给一颗 n 个点的树,每次给 x 到 y 的树上路径的每一个点放置一个物品 z,问从 1 到 n 的每个点中,物品个数最多的是哪种(一样则取最小)。
解题思路:
首先进行树链剖分,将树上路径变成 O(logn) 条链。此问题变转化为上面的问题。
这里的权值线段树维护区间最大值、区间最大值的位置在哪里。
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> pii;
#define mp make_pair
const int N = 1e5 + 50;
int n, m, o;
int ans[N];
int head[N], tot;
struct Edge {
int nex, to;
}e[N << 1];
struct Node {
int dep, fa, sz;
int mson, top;
}nd[N];
void add(int a, int b) {
e[++tot] = (Edge) {head[a], b};
head[a] = tot;
}
void dfs1(int rt, int fa, int dep) {
nd[rt].dep = dep; nd[rt].fa = fa; nd[rt].sz = 1;
int mson = -1;
for (int i = head[rt]; i; i = e[i].nex) {
int to = e[i].to;
if (to == fa) continue;
dfs1(to, rt, dep + 1);
nd[rt].sz += nd[to].sz;
if (nd[to].sz > mson) {
mson = nd[to].sz;
nd[rt].mson = to;
}
}
}
int cnt, id[N], di[N];
void dfs2(int rt, int tp) {
id[rt] = ++cnt; di[cnt] = rt;
nd[rt].top = tp;
if (!nd[rt].mson) return;
dfs2(nd[rt].mson, tp);
for (int i = head[rt]; i; i = e[i].nex) {
int to = e[i].to;
if (to == nd[rt].fa || to == nd[rt].mson) continue;
dfs2(to, to);
}
}
vector<pii> t[N];
#define topx nd[x].top
#define topy nd[y].top
void lca(int x, int y, int z) {
while (topx != topy) {
if (nd[topx].dep > nd[topy].dep) swap(x, y);
t[id[topy]].push_back(mp(z, 1));
t[id[y] + 1].push_back(mp(z, -1));
y = nd[topy].fa;
}
if (nd[x].dep > nd[y].dep) swap(x, y);
t[id[x]].push_back(mp(z, 1));
t[id[y] + 1].push_back(mp(z, -1));
}
#define ls (rt<<1)
#define rs (rt<<1|1)
#define mid ((l+r)>>1)
struct Segment_Tree {
int id, val;
}a[N << 2];
void pushup(int rt) {
a[rt].val = max(a[ls].val, a[rs].val);
a[rt].id = (a[rt].val == a[ls].val ? a[ls].id : a[rs].id);
}
void add(int rt, int l, int r, int id, int val) {
if (l == r) {
a[rt].val += val;
a[rt].id = (a[rt].val ? id : 0);
return;
}
if (id <= mid) add(ls, l, mid, id, val);
else add(rs, mid + 1, r, id, val);
pushup(rt);
}
int main(int argc, const char * argv[]) {
scanf("%d%d", &n, &m);
for (int i = 1, a, b; i < n; i++) {
scanf("%d%d", &a, &b);
add(a, b); add(b, a);
}
dfs1(1, 0, 1); dfs2(1, 1);
for (int i = 1, a, b, c; i <= m; i++) {
scanf("%d%d%d", &a, &b, &c);
o = max(o, c);
lca(a, b, c);
}
for (int i = 0; i <= n; i++) {
for (auto x : t[i]) {
add(1, 1, o, x.first, x.second);
}
ans[di[i]] = a[1].id;
}
for (int i = 1; i <= n; i++) {
printf("%d\n", ans[i]);
}
return 0;
}