点分治是大规模处理树上路径问题的工具。大意是找到一个点,递归统计其所有子树的答案,然后利用容斥原理或其它方式合并答案,最后得到整棵树的答案。
步骤
点分治的大体代码框架是这样的:首先对整棵树找重心点 $rt$,然后从 $rt$ 开始,向下递归求解。
首先,统计以 $rt$ 为根的树的答案,再对 $rt$ 的每棵子树 $ch_i$ 求解,统计得到以 $ch_i$ 为根的子树答案;根据容斥原理,$ans_{rt} - \Sigma _{i=0}^{chsize}ans_i$ 即为整棵树的答案。
第一步:找重心
重心的定义:在一棵树上找一个点,使得该点所有的子树中,最大的子树节点数最少,那么这个点就是这棵树的重心。
找重心的意义:点分治的过程中,要找到一个起始点统计答案,并且从这个点向子树递归。那么选择合理的起始点是很有必要的。
如果数据呈链状结构的时候,头铁地选两端的节点作为根,那么子树的层数是 $O(n)$ 的,对于每个点统计其子树的答案也是 $O(n)$ 的,总时间复杂度是 $O(n^2)$。
而如果选择重心为根,那么每棵子树的大小都不超过 $\frac{n}{2}$,总的递归层数不会超过 $O(log n)$,这样就将复杂度从 $O(n^2)$ 降低到了 $O(nlogn)$。对复杂度的证明在这篇文章。
找重心的步骤
根据重心性质,我们只需要进行一次 DFS,在搜索的过程中记录以节点 $i$ 为根的子树大小 $size[i]$ 和 $i$ 的最大子树的大小 $son[i]$ ,最后选出 $son[i]$ 最小的节点即可。
注意,因为进行点分治的树一般是无根树,而我们 DFS 只能朝一个方向统计,因此计算 $son[i]$ 的时候,还要考虑其另一端的节点大小。假设当前统计的子树共有 $sz$ 个节点,那么对 $i$ 下方的所有子树统计完 $son[i]$ 后,需要额外考虑上端的节点,所以最终 $son[i] = max(son[i], sz - son[i])$.
记得在找重心的过程中,跳过已经被打标记的点和父节点。正确地运用 $vis$ 标记是点分治复杂度正确的保证。
void findRoot(int u, int p) {
size[u] = 1, son[u] = 0;
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)] || v == p)
continue;
findRoot(v, u);
size[u] += size[v];
son[u] = max(son[u], size[v]);
}
son[u] = max(son[u], tot - size[u]);
if (son[u] < minv)
minv = son[u], rt = u;
}
第二步:统计答案
现在我们考虑从一个任意点开始,如何统计答案。以 POJ1741 Tree 为例题。
给定一个 $n$ 个顶点的树,每条边有一个长度 $w$。定义 $dist(u,v)$ 为点 $u, v$ 之间的最短距离。现给一个整数 $k$,问有多少对 $(u, v)$ 满足 $dist(u, v) \le k$. $(n \le 10000)$.
这道题目的数据范围显然要求我们使用 $O(n^2)$ 以下的复杂度解决,所以考虑点分治。
现在我们考虑:如何统计以点重心点 $i$ 为根的子树下,满足 $dist(u, v) \le k$ 的点对 $(u, v)$ 的个数。
要求符合条件的点个数,那么首先需要知道两点路径长度。直接暴力枚举两个点算距离当然是不可行的,时间复杂度高达 $O(n^3)$。即便使用 LCA 优化,也会达到 $O(n^2logn)$.
此时就要运用点分治的思想:每次考虑一个点,统计该点的答案。所以我们在统计某个点 $i$ 的时候,只需要考虑经过 $i$ 的路径即可。
那么首先 DFS 在 $O(n)$ 时间内,求出 $i$ 的子树下每个点 $j$ 到 $i$ 的距离 $dist(i, j)$。
void getDist(int u, int p) {
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)] || v == p)
continue;
dep[v] = dep[u] + e[i].w;
getDist(v, u);
}
}
此时,假设在以 $i$ 为根的子树中,存在符合条件的点对路径 $(u,v), u \neq v$,那么可能有以下两种情况:
- $dist(u, i) + dist(i, v) \le k$ ,即在 $i$ 的两棵不同子树中各找两个点,这两个点到 $i$ 的路径之和小于等于 $k$.
- $dist(u, x) + dist(x, v) \le k$,$x$ 是 $i$ 子树中一个异于 $u, v$ 的节点。
根据最小路径的定义,因为 $x$ 是 $i$ 的一个子节点,所以路径 $(u,x,v)$ 不会经过 $i$,从而 $u, v, x$ 一定在 $i$ 的某个子树中。那么对于情况 2 的这部分答案,我们只要对 $i$ 的子树进行递归处理,就可以将其转化为某个节点的情况 1。
也就是说,我们统计的时候只要考虑计算情况 1 中的合法答案就可以了。
怎样统计呢?对于这道题目而言,我们将每次对重心点 $i$ 使用 getDist(i, 0)
求出其所有子节点到它自己的距离,存进一个数组 $d$ 中(不要忘记 $i$ 到自身的距离为 0)。要统计有多少点对满足 $dist(u, i) + dist(i, v) \le k$,也就是计算 $d$ 数组中有多少点对之和不超过 $k$。
计算答案的时候可以使用双指针移动法:首先将 $d$ 数组排序。假设当前 $d$ 数组中元素个数为 $dcnt$:
- 令 $i = 1 \to dcnt-2, j = dcnt - 1$;
- 对于每一个 $i$,判断 $d[i] + d[j]$ 是否不超过 $k$.
- 如果超过 $k$ 则将 $j$ 前移,否则当前的 $i, j$ 会贡献 $j-i$ 个答案,然后可以对 $i$ 后移继续统计。
上述方法的复杂度是 $O(dcnt)$ 的。当然,对于每一个 $i$,也可以直接二分搜索得到最后一个符合条件的 $j$.
void getDist(int u, int p) {
d[dcnt++] = dep[u];
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)] || v == p)
continue;
dep[v] = dep[u] + e[i].w;
getDist(v, u);
}
}
int getAns(int u) {
int tmpans = 0;
dcnt = 0;
getDist(u, 0);
sort(d, d + dcnt);
int i = 0, j = dcnt - 1;
while (i < j) {
if (d[i] + d[j] <= k)
tmpans += j - i, i++;
else
j--;
}
return tmpans;
}
细心的读者一定发现,刚刚说到考虑情况 1 的时候,有四个字被加粗了——“合法答案”。
为什么说是合法答案?根据题意,我们发现最小距离的定义,本质上其实是两个点的 LCA 到各自的距离之和,即:
$$dist(u, v) = dist(u, lca(u, v)) + dist(v, lca(u, v))$$
但因为我们只计算了 $dist(u, i)$ 和 $dist(v, i)$,有可能 $u, v$ 的最近公共祖先是 $i$ 子树中的另一个节点 $j$ 而不是 $i$。
也就是说实际上存在 $dist(u, v) = dist(u, j) + dist(v, j) < dist(u, i) + dist(v, i)$ 的情况。因此,我们并不能得到 $dist(i, *)$ 后,就简单相加来统计答案。
怎么解决这个问题呢?
利用容斥原理去除非法路径
假设我们对当前树根节点 $i$ 统计答案,求出了一个点对 $(u, v)$ 满足 $dist(u, i) + dist(i, v) \le k$.
同时对于 $i$ 的某一个直接子节点 $j$ ,满足 $dist(u, j) + dist(j, v) \le k$.
那么显然我们不应该在统计 $i$ 的时候将点对 $(u, v)$ 计入答案中,一个原因是 $dist(u, i) + dist(i, v)$ 不满足最短距离的定义,另一个原因是等会如果我们统计 $j$ 所在的子树的时候会将点对 $(u, v)$ 重复计算。
这样一来,去掉它的方法就很显然了:
我们首先统计 $i$ 的答案,即满足 $dist(u, i) + dist(i, v) \le k$ 的点对数,记为 $ans_i$.
然后,对于 $i$ 的所有直接子节点 $j$,统计同时经过 $i,j$ 且满足条件的点对,即经过 $i,j$ 且满足 $dist(u, i) + dist(i, v) \le k$ 的点对数,记为 $ans_j$.
那么最终 $i$ 点贡献的情况 1 的总答案数为 $ans_i - \Sigma ans_j$.
也就是像这样:
solve(int u) {
ans += getAns(u);
for (auto j : son(u))
ans -= getAns(j);
}
这样我们就去除了理应在 $i$ 的子树中被统计的那部分重复答案。
注意的是,我们在统计经过 $i$ 和 $i$ 的子节点 $j$ 的时候,求的仍然是满足条件 $dist(u, i) + dist(i, v) \le k$ 的节点,而不是满足 $dist(u, j) + dist(j, v) \le k$ 的节点(后者是我们递归子节点时候才统计的)。
也就是说,这里满足的条件始终是相对 $i$ 而言的,所以容斥合并的过程,并不是一个递归的过程,只需要减去 $i$ 的所有直接子节点 $j$ 的 $ans_j$,而不需要对 $j$ 的子节点继续容斥合并。
实际上,利用容斥原理合并答案只是一个比较常用的处理方法。对于其他题目,也可以有不同的处理方法——例如,对来自每个子树的答案染色,合并时不处理两个来自同一个子树的结果即可。
第三步:分治求解
要注意的是,我们分治求解的节点,永远是整棵树的重心节点,或者是上一次递归处理的节点的子树的重心节点。所以分治求解的步骤很简单:
找重心 $rt$ → 统计重心答案 $ans_{rt}$ → 对于 $rt$ 的每一棵子树 $T_i$ 找重心 $rt_i$ → 递归分治求解 $rt_i$.
统计答案的思想在上一步我们已经解决了。所以代码我们可以很容易地写出来:
void solve(int u) {
dep[u] = 0;
ans += getAns(u);
vis[u] = 1;
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)])
continue;
ans -= getAns(v);
tot = size[v];
mins = INF;
findRoot(v, 0);
solve(rt);
}
}
现在来说说,刚刚我们一直没有说到 $vis$ 标记的作用。注意到我们每一次分治的下一个点都是子树的重心,这个重心点极有可能不是当前点 $u$ 的子节点,那么对这个节点分治的时候,方向是不确定的,有可能会向后回到 $u$,所以使用 $vis$ 对处理过的点进行标记,才能保证点分治的时间复杂度和正确性。
另外,当我们统计点 $u$ 的时候要算的是其它子节点到 $u$ 的距离,而 $u$ 到自身的距离为 $0$, 所以统计的时候要初始化 $dep[u] = 0$.
总的代码会像下面这样:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
struct Edge {
int u, v, w, next;
};
const int MAXN = 10050,
INF = 0x3f3f3f3f;
int n, k, cnt = 1, dcnt = 0, rt = 0, mins = INF, tot = 0, ans = 0;
int head[MAXN], son[MAXN], size[MAXN], d[MAXN], dep[MAXN];
bool vis[MAXN];
Edge e[MAXN << 1];
template<class T> void read(T &x) {
T a = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9') {
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
a = a * 10 + ch - '0';
ch = getchar();
}
x = a * f;
}
void findRoot(int u, int p) {
size[u] = 1, son[u] = 0;
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)] || v == p)
continue;
findRoot(v, u);
size[u] += size[v];
son[u] = max(son[u], size[v]);
}
son[u] = max(son[u], tot - son[u]);
if (son[u] < mins)
mins = son[u], rt = u;
}
void getDist(int u, int p) {
d[dcnt++] = dep[u];
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)] || v == p)
continue;
dep[v] = dep[u] + e[i].w;
getDist(v, u);
}
}
int getAns(int u) {
int tmpans = 0;
dcnt = 0;
getDist(u, 0);
sort(d, d + dcnt);
int i = 0, j = dcnt - 1;
while (i < j) {
if (d[i] + d[j] <= k)
tmpans += j - i, i++;
else
j--;
}
return tmpans;
}
void solve(int u) {
dep[u] = 0;
ans += getAns(u);
vis[u] = 1;
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)])
continue;
ans -= getAns(v);
tot = size[v];
mins = INF;
findRoot(v, 0);
solve(rt);
}
}
void add_edge(int u, int v, int w) {
e[cnt] = (Edge){ u, v, w, head[u] };
head[u] = cnt++;
}
int main() {
while (1) {
read(n), read(k);
if (n == 0 && k == 0)
break;
memset(head, 0, sizeof head);
memset(vis, 0, sizeof vis);
memset(size, 0, sizeof size);
memset(son, 0, sizeof son);
memset(dep, 0, sizeof dep);
cnt = 1, dcnt = 0, rt = 0, mins = INF, tot = 0, ans = 0;
for (int i = 0, u, v, w; i < n - 1; i++) {
read(u), read(v), read(w);
add_edge(u, v, w), add_edge(v, u, w);
}
tot = n;
findRoot(1, 0);
solve(rt);
printf("%d\n", ans);
}
return 0;
}
例题
计算一棵树上距离为 $k$ 的点是否存在
给定一棵有 $n$ 个点的树,$m$ 次询问树上距离为 $k$ 的点对是否存在。($n \le 10000, m \le 100$)
这道题和树上距离(路径)相关,并且 $n$ 的规模达到了 $10000$, 所以我们考虑使用点分治解决(废话那不然怎么叫点分治模板).
找重心和分治的步骤大家都差不多,关键是统计答案的步骤。首先如何统计距离为 $k$ 的点对个数,显然我们还是对于每一个重心点 $i$ 求其他点到 $i$ 的距离 $d[i]$,然后在 $d$ 中找两个来自不同子树的点 $d[u], d[v]$ 且满足 $d[u] + d[v] = k$.
虽然这道题不像上面那道题一样求小于等于的情况,不适合用双指针移动的线性方法统计;但 $O(n^2)$ 的暴力枚举法及其优化都是很显然的——很自然地想到用二分查找就可以计数了,对吧。首先我们枚举每一个 $d[i]$,在 $d$ 数组中二分查找 $k - d[i]$ 是否存在(只需要找到一个即可),用 STL 的 lower_bound
函数就可以啦。
接下来是去掉不合法的答案。可以用上面那道题的的方法;但因为这道题是要判断存在性,我们可以直接用染色法:计算 $d$ 数组的时候同时记录其来自重心点 $rt$ 的哪一个子节点 $j$,然后在处理 $d[i]$、判断 $k - d[i]$ 的存在性时,忽略那些与 $d[i]$ 来自同一个子节点的 $k - d[i]$ 即可。
最后,因为这道题要询问多个 $k$,但是询问数量 $m$ 是很小的,因此我们每次统计答案的时候就把所有的询问都扫一遍,一个个判断是否符合并标记;若对于已经存在的询问 $k$ 则可以直接跳过。
参考代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
struct Edge {
int u, v, next;
ll w;
};
struct Dist {
ll w;
int par;
bool operator < (const Dist &b) const {
return w < b.w;
}
};
const int MAXN = 10050,
MAXM = 150,
INF = 0x3f3f3f3f;
int n, m, cnt = 1, minv = INF, rt = 0, tot = 0, dcnt = 0;
bool vis[MAXN];
int head[MAXN], qry[MAXM], ans[MAXM], size[MAXN], son[MAXN];
Edge e[MAXN << 1];
Dist d[MAXN];
template<class T> void read(T &x) {
T a = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9') {
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9') {
a = a * 10 + ch - '0';
ch = getchar();
}
x = a * f;
}
void add_edge(int u, int v, ll w) {
e[cnt] = (Edge){ u, v, head[u], w };
head[u] = cnt++;
}
void findRoot(int u, int p) {
size[u] = 1, son[u] = 0;
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)] || v == p)
continue;
findRoot(v, u);
size[u] += size[v];
son[u] = max(son[u], size[v]);
}
son[u] = max(son[u], tot - size[u]);
if (son[u] < minv)
minv = son[u], rt = u;
}
void getDist(int u, int p, int par, ll dt) {
d[dcnt++] = (Dist) { dt, par };
for (int i = head[u], v; i; i = e[i].next) {
if (vis[(v = e[i].v)] || p == v)
continue;
getDist(v, u, par, dt + e[i].w);
}
}
void solve(int cur) {
dcnt = 0;
for (int i = head[cur], v; i; i = e[i].next) {
if (vis[(v = e[i].v)])
continue;
getDist(v, cur, v, e[i].w);
}
d[dcnt++] = (Dist){ 0ll, 0 };
sort(d, d + dcnt);
for (int i = 0; i < m; i++) {
if (ans[i])
continue;
int l = 0;
while (l < dcnt && d[l].w + d[dcnt-1].w < qry[i])
l++;
while (l < dcnt && !ans[i]) {
if (qry[i] - d[l].w < d[l].w)
break;
int pos = lower_bound(d, d + dcnt, (Dist){ qry[i] - d[l].w, 0 }) - d;
while (pos < dcnt && d[pos].par == d[l].par)
pos++;
if (pos < dcnt && d[pos].w + d[l].w == qry[i])
ans[i] = 1;
l++;
}
}
}
void work(int x) {
vis[x] = 1;
solve(x);
for (int i = head[x], v; i; i = e[i].next) {
if (vis[(v = e[i].v)])
continue;
rt = 0, minv = INF, tot = size[v];
findRoot(v, 0);
work(rt);
}
}
int main() {
read(n), read(m);
for (int i = 0, u, v; i < n-1; i++) {
ll w;
read(u), read(v), read(w);
add_edge(u, v, w);
add_edge(v, u, w);
}
for (int i = 0; i < m; i++)
read(qry[i]);
tot = n;
findRoot(1, 0);
work(rt);
for (int i = 0; i < m; i++)
printf(ans[i] ? "AYE\n" : "NAY\n");
return 0;
}