点差分
例题:P3128 [USACO15DEC]Max Flow P
题目大意:
给 k 条树上路径,求哪个点被经过的次数最多、被经过了多少次。
解题思路:
树上差分算法。对于每条路径 (a,b),对 num[a]++
,num[b]++
,num[lca(a, b)]--
,num[fa[lca(a, b)]]--
。
这样再做一次 dfs 后,发现从 a 到 b 路径上的每个点都加 1,而 lca(a, b)
的父亲节点不变。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5e4 + 50;
int n, k, ans;
int num[N], cnt[N];
// 树上差分数组,经过这个点的路径数
int head[N], tot;
struct Edge {
int nex, to;
}e[N << 1];
void add(int a, int b) {
e[++tot] = (Edge) {head[a], b};
head[a] = tot;
}
int lg[N], fa[N][22], dep[N];
void init() {
for (int i = 1; i < N; i++) {
lg[i] = lg[i-1] + ((1 << lg[i-1]) == i);
}
}
void dfs(int u, int F) {
dep[u] = dep[F] + 1;
fa[u][0] = F;
for(int i = 1; (1 << i) <= dep[u]; i++) {
fa[u][i] = fa[fa[u][i-1]][i-1];
}
for(int i = head[u]; i; i = e[i].nex) {
int to = e[i].to;
if(to != F) dfs(to, u);
}
}
int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
while (dep[x] > dep[y]) x = fa[x][lg[dep[x] - dep[y]] - 1];
if (x == y) return x;
for (int i = lg[dep[x]] - 1; i >= 0; i--) {
if (fa[x][i] != fa[y][i]) {
x = fa[x][i]; y = fa[y][i];
}
}
return fa[x][0];
}
void add2(int x, int y, int w = 1) {
num[x] += w; num[y] += w;
int a = lca(x, y), b = fa[a][0];
num[a] -= w; num[b] -= w;
}
void dfs2(int u, int F) {
cnt[u] = num[u];
for (int i = head[u]; i; i = e[i].nex) {
int to = e[i].to;
if (to == F) {
continue;
}
dfs2(to, u);
cnt[u] += cnt[to];
}
}
int main(int argc, const char * argv[]) {
init();
scanf("%d%d", &n, &k);
for (int i = 1, a, b; i < n; i++) {
scanf("%d%d", &a, &b);
add(a, b); add(b, a);
}
dfs(1, 0);
for (int i = 1, a, b; i <= k; i++) {
scanf("%d%d", &a, &b);
add2(a, b);
}
dfs2(1, 0);
for (int i = 1; i <= n; i++) {
ans = max(ans, cnt[i]);
}
printf("%d\n", ans);
return 0;
}
例题2:Gym – 102012G Rikka with Intersections of Paths
题目大意:
给 m 条树上路径。问有多少种方案,从中选择 k 条,这 k 条路径至少有一个公共点。
问题分析:
有一个结论:
一个树上任意两条路径如果有交点的话,那么这些交点中肯定有一个为这两条路径中的一条路径两端点的 lca;其余交点均不为 lca。
衍生出这个结论:
k 条树上路径如果有多个交点,有且只有一个是其中一条路径的 lca。
考虑每一个点的贡献。我们可以用树上差分得到每个点被经过的次数,同时预处理出每个点作为 lca 的次数。由上面的结论,计算每个点的贡献时,我们只计算它至少为某一条路径的 lca 时的次数即可。
ll num[N], cnt[N], tim[N];
// 树上差分数组,经过这个点的路径数,LCA是这个点的路径数
/*
略去组合数、LCA、树上差分模版
*/
int main(int argc, const char * argv[]) {
cin >> T;
init_all();
while (T--) {
ll ans = 0;
cin >> n >> m >> k;
for (int i = 1, a, b; i < n; i++) {
cin >> a >> b;
add(a, b); add(b, a);
}
dfs(1, 0);
for (int i = 1, a, b; i <= m; i++) {
cin >> a >> b;
num[a]++; num[b]++;
int c = lca(a, b), d = fa[c][0];
num[c]--; num[d]--;
tim[c]++;
}
dfs2(1, 0);
for (int i = 1; i <= n; i++) {
(ans += C(cnt[i], k) - C(cnt[i] - tim[i], k) + P) %= P;
}
cout << ans << "\n";
clear();
}
return 0;
}
边差分
对于边权的树上差分,我们考虑将边权赋值给儿子节点;同时差分数组也应稍作修改:
num[x]++; num[y]++
num[lca(x, y)] -= 2
例题:P2680 运输计划
题目大意:
给一颗树,每条边有边权;给 m 条树上路径。现在可以把一条树边权值变为 0,问这 m 条路径的最大值,最小为多少(即最小化最大值)。
解题思路:
注意到题目要求最小化的最大值,显然答案满足单调性,因此考虑二分答案。
首先进行树链剖分、预处理出 m 条路径的 lca 以及路径长度 len。在二分答案时,假如正在 check(x)
:
我们挑选出 m 条路径中所有长度大于 x 的路径(假设有 cnt 条),那么删除的边一定是这 cnt 条路径的公共边。对这 cnt 条路径做边差分,可以得到每条边被经过了多少次。对于所有经过了 cnt 次的边,如果其权值大于等于 maxdis-x,则 check 成功。
编写代码要注意,在树链剖分 dfs1 时,将边权赋值给儿子的点权;同时处理出所有点到根结点的举例 dis。计算树上两点距离时,有:
len(x, y) = dis[x] + dis[y] - 2 * dis[lca(x, y)]
#include <bits/stdc++.h>
using namespace std;
const int N = 3e5 + 50;
int n, m;
int head[N], tot;
struct Edge {
int nex, to, w;
}e[N << 1];
struct Node {
int dep, fa, sz;
int mson, top;
int val, dis;
}nd[N];
struct Query {
int x, y;
int lca, len;
}q[N];
void add(int a, int b, int c) {
e[++tot] = (Edge) {head[a], b, c};
head[a] = tot;
}
void dfs1(int rt, int fa, int dep) {
nd[rt].dep = dep; nd[rt].fa = fa; nd[rt].sz = 1;
int mson = -1;
for (int i = head[rt]; i; i = e[i].nex) {
int to = e[i].to;
if (to == fa) continue;
nd[to].val = e[i].w;
nd[to].dis = nd[rt].dis + e[i].w;
dfs1(to, rt, dep + 1);
nd[rt].sz += nd[to].sz;
if (nd[to].sz > mson) {
mson = nd[to].sz;
nd[rt].mson = to;
}
}
}
void dfs2(int rt, int tp) {
nd[rt].top = tp;
if (!nd[rt].mson) return;
dfs2(nd[rt].mson, tp);
for (int i = head[rt]; i; i = e[i].nex) {
int to = e[i].to;
if (to == nd[rt].fa || to == nd[rt].mson) continue;
dfs2(to, to);
}
}
#define topx nd[x].top
#define topy nd[y].top
int lca(int x, int y) {
while (topx != topy) {
if(nd[topx].dep < nd[topy].dep) y = nd[topy].fa;
else x = nd[topx].fa;
}
return nd[x].dep < nd[y].dep ? x : y;
}
int num[N];
void dfs3(int u, int fa) {
for (int i = head[u]; i; i = e[i].nex) {
int to = e[i].to;
if (to == fa) continue;
dfs3(to, u);
num[u] += num[to];
}
}
bool check(int x) {
int cnt = 0, maxdis = 0;
memset(num, 0, sizeof(num));
for (int i = 1; i <= m; i++) {
if (q[i].len <= x) continue;
num[q[i].x]++; num[q[i].y]++;
num[q[i].lca] -= 2;
maxdis = max(maxdis, q[i].len);
cnt++;
}
dfs3(1, 0);
for (int i = 1; i <= n; i++) {
if (num[i] == cnt) {
if (maxdis - nd[i].val <= x) return true;
}
}
return false;
}
int main(int argc, const char * argv[]) {
int l = -1, r = 0;
scanf("%d%d", &n, &m);
for (int i = 1, a, b, c; i < n; i++) {
scanf("%d%d%d", &a, &b, &c);
add(a, b, c); add(b, a, c);
}
dfs1(1, 0, 1); dfs2(1, 1);
for (int i = 1; i <= m; i++) {
scanf("%d%d", &q[i].x, &q[i].y);
q[i].lca = lca(q[i].x, q[i].y);
q[i].len = nd[q[i].x].dis + nd[q[i].y].dis - 2 * nd[q[i].lca].dis;
r = max(r, q[i].len);
}
while (l + 1 < r) {
int mid = (l + r) >> 1;
if (check(mid)) r = mid;
else l = mid;
}
printf("%d\n", r);
return 0;
}