【学习笔记】点分治

点分治是大规模处理树上路径问题的工具。大意是找到一个点,递归统计其所有子树的答案,然后利用容斥原理或其它方式合并答案,最后得到整棵树的答案。

步骤

点分治的大体代码框架是这样的:首先对整棵树找重心点 $rt$,然后从 $rt$ 开始,向下递归求解。

首先,统计以 $rt$ 为根的树的答案,再对 $rt$ 的每棵子树 $ch_i$ 求解,统计得到以 $ch_i$ 为根的子树答案;根据容斥原理,$ans_{rt} - \Sigma _{i=0}^{chsize}ans_i$ 即为整棵树的答案。

第一步:找重心

重心的定义:在一棵树上找一个点,使得该点所有的子树中,最大的子树节点数最少,那么这个点就是这棵树的重心。

找重心的意义:点分治的过程中,要找到一个起始点统计答案,并且从这个点向子树递归。那么选择合理的起始点是很有必要的。

如果数据呈链状结构的时候,头铁地选两端的节点作为根,那么子树的层数是 $O(n)$ 的,对于每个点统计其子树的答案也是 $O(n)$ 的,总时间复杂度是 $O(n^2)$。

而如果选择重心为根,那么每棵子树的大小都不超过 $\frac{n}{2}$,总的递归层数不会超过 $O(log n)$,这样就将复杂度从 $O(n^2)$ 降低到了 $O(nlogn)$。对复杂度的证明在这篇文章

找重心的步骤

根据重心性质,我们只需要进行一次 DFS,在搜索的过程中记录以节点 $i$ 为根的子树大小 $size[i]$ 和 $i$ 的最大子树的大小 $son[i]$ ,最后选出 $son[i]$ 最小的节点即可。

注意,因为进行点分治的树一般是无根树,而我们 DFS 只能朝一个方向统计,因此计算 $son[i]$ 的时候,还要考虑其另一端的节点大小。假设当前统计的子树共有 $sz$ 个节点,那么对 $i$ 下方的所有子树统计完 $son[i]$ 后,需要额外考虑上端的节点,所以最终 $son[i] = max(son[i], sz - son[i])$.

记得在找重心的过程中,跳过已经被打标记的点和父节点。正确地运用 $vis$ 标记是点分治复杂度正确的保证。

void findRoot(int u, int p) {
    size[u] = 1, son[u] = 0;
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)] || v == p)
            continue;
        findRoot(v, u);
        size[u] += size[v];
        son[u] = max(son[u], size[v]);
    }
    son[u] = max(son[u], tot - size[u]);
    if (son[u] < minv)
        minv = son[u], rt = u;		
}

第二步:统计答案

现在我们考虑从一个任意点开始,如何统计答案。以 POJ1741 Tree 为例题。

给定一个 $n$ 个顶点的树,每条边有一个长度 $w$。定义 $dist(u,v)$ 为点 $u, v$ 之间的最短距离。现给一个整数 $k$,问有多少对 $(u, v)$ 满足 $dist(u, v) \le k$. $(n \le 10000)$.

这道题目的数据范围显然要求我们使用 $O(n^2)$ 以下的复杂度解决,所以考虑点分治。

现在我们考虑:如何统计以点重心点 $i$ 为根的子树下,满足 $dist(u, v) \le k$ 的点对 $(u, v)$ 的个数。

要求符合条件的点个数,那么首先需要知道两点路径长度。直接暴力枚举两个点算距离当然是不可行的,时间复杂度高达 $O(n^3)$。即便使用 LCA 优化,也会达到 $O(n^2logn)$.

此时就要运用点分治的思想:每次考虑一个点,统计该点的答案。所以我们在统计某个点 $i$ 的时候,只需要考虑经过 $i$ 的路径即可。

那么首先 DFS 在 $O(n)$ 时间内,求出 $i$ 的子树下每个点 $j$ 到 $i$ 的距离 $dist(i, j)$。

void getDist(int u, int p) {
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)] || v == p)	
            continue;
        dep[v] = dep[u] + e[i].w;
        getDist(v, u);
    }
}

此时,假设在以 $i$ 为根的子树中,存在符合条件的点对路径 $(u,v), u \neq v$,那么可能有以下两种情况:

  1. $dist(u, i) + dist(i, v) \le k$ ,即在 $i$ 的两棵不同子树中各找两个点,这两个点到 $i$ 的路径之和小于等于 $k$.
  2. $dist(u, x) + dist(x, v) \le k$,$x$ 是 $i$ 子树中一个异于 $u, v$ 的节点。

根据最小路径的定义,因为 $x$ 是 $i$ 的一个子节点,所以路径 $(u,x,v)$ 不会经过 $i$,从而 $u, v, x$ 一定在 $i$ 的某个子树中。那么对于情况 2 的这部分答案,我们只要对 $i$ 的子树进行递归处理,就可以将其转化为某个节点的情况 1。

也就是说,我们统计的时候只要考虑计算情况 1 中的合法答案就可以了。

怎样统计呢?对于这道题目而言,我们将每次对重心点 $i$ 使用 getDist(i, 0) 求出其所有子节点到它自己的距离,存进一个数组 $d$ 中(不要忘记 $i$ 到自身的距离为 0)。要统计有多少点对满足 $dist(u, i) + dist(i, v) \le k$,也就是计算 $d$ 数组中有多少点对之和不超过 $k$。

计算答案的时候可以使用双指针移动法:首先将 $d$ 数组排序。假设当前 $d$ 数组中元素个数为 $dcnt$:

  • 令 $i = 1 \to dcnt-2, j = dcnt - 1$;
  • 对于每一个 $i$,判断 $d[i] + d[j]$ 是否不超过 $k$.
  • 如果超过 $k$ 则将 $j$ 前移,否则当前的 $i, j$ 会贡献 $j-i$ 个答案,然后可以对 $i$ 后移继续统计。

上述方法的复杂度是 $O(dcnt)$ 的。当然,对于每一个 $i$,也可以直接二分搜索得到最后一个符合条件的 $j$.

void getDist(int u, int p) {
    d[dcnt++] = dep[u];
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)] || v == p)	
            continue;
        dep[v] = dep[u] + e[i].w;
        getDist(v, u);
    }
}

int getAns(int u) {
    int tmpans = 0;
    dcnt = 0;
    getDist(u, 0);
    sort(d, d + dcnt);
    
    int i = 0, j = dcnt - 1;
    while (i < j) {
        if (d[i] + d[j] <= k)
            tmpans += j - i, i++;
        else
            j--;
    }
    return tmpans;
}

细心的读者一定发现,刚刚说到考虑情况 1 的时候,有四个字被加粗了——“合法答案”。

为什么说是合法答案?根据题意,我们发现最小距离的定义,本质上其实是两个点的 LCA 到各自的距离之和,即:

$$dist(u, v) = dist(u, lca(u, v)) + dist(v, lca(u, v))$$

但因为我们只计算了 $dist(u, i)$ 和 $dist(v, i)$,有可能 $u, v$ 的最近公共祖先是 $i$ 子树中的另一个节点 $j$ 而不是 $i$。

也就是说实际上存在 $dist(u, v) = dist(u, j) + dist(v, j) < dist(u, i) + dist(v, i)$ 的情况。因此,我们并不能得到 $dist(i, *)$ 后,就简单相加来统计答案。

怎么解决这个问题呢?

利用容斥原理去除非法路径

假设我们对当前树根节点 $i$ 统计答案,求出了一个点对 $(u, v)$ 满足 $dist(u, i) + dist(i, v) \le k$.

同时对于 $i$ 的某一个直接子节点 $j$ ,满足 $dist(u, j) + dist(j, v) \le k$.

那么显然我们不应该在统计 $i$ 的时候将点对 $(u, v)$ 计入答案中,一个原因是 $dist(u, i) + dist(i, v)$ 不满足最短距离的定义,另一个原因是等会如果我们统计 $j$ 所在的子树的时候会将点对 $(u, v)$ 重复计算。

这样一来,去掉它的方法就很显然了:

  • 我们首先统计 $i$ 的答案,即满足 $dist(u, i) + dist(i, v) \le k$ 的点对数,记为 $ans_i$.

  • 然后,对于 $i$ 的所有直接子节点 $j$,统计同时经过 $i,j$ 且满足条件的点对,即经过 $i,j$ 且满足 $dist(u, i) + dist(i, v) \le k$ 的点对数,记为 $ans_j$.

  • 那么最终 $i$ 点贡献的情况 1 的总答案数为 $ans_i - \Sigma ans_j$.

也就是像这样:

solve(int u) {
    ans += getAns(u);
    for (auto j : son(u))
        ans -= getAns(j);
}

这样我们就去除了理应在 $i$ 的子树中被统计的那部分重复答案。

注意的是,我们在统计经过 $i$ 和 $i$ 的子节点 $j$ 的时候,求的仍然是满足条件 $dist(u, i) + dist(i, v) \le k$ 的节点,而不是满足 $dist(u, j) + dist(j, v) \le k$ 的节点(后者是我们递归子节点时候才统计的)。

也就是说,这里满足的条件始终是相对 $i$ 而言的,所以容斥合并的过程,并不是一个递归的过程,只需要减去 $i$ 的所有直接子节点 $j$ 的 $ans_j$,而不需要对 $j$ 的子节点继续容斥合并。

实际上,利用容斥原理合并答案只是一个比较常用的处理方法。对于其他题目,也可以有不同的处理方法——例如,对来自每个子树的答案染色,合并时不处理两个来自同一个子树的结果即可。

第三步:分治求解

要注意的是,我们分治求解的节点,永远是整棵树的重心节点,或者是上一次递归处理的节点的子树的重心节点。所以分治求解的步骤很简单:

找重心 $rt$ → 统计重心答案 $ans_{rt}$ → 对于 $rt$ 的每一棵子树 $T_i$ 找重心 $rt_i$ → 递归分治求解 $rt_i$.

统计答案的思想在上一步我们已经解决了。所以代码我们可以很容易地写出来:

void solve(int u) {
    dep[u] = 0;
    ans += getAns(u);
    vis[u] = 1;
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)])
            continue;
        ans -= getAns(v);
        tot = size[v];
        mins = INF;
        findRoot(v, 0);
        solve(rt);
    }
}

现在来说说,刚刚我们一直没有说到 $vis$ 标记的作用。注意到我们每一次分治的下一个点都是子树的重心,这个重心点极有可能不是当前点 $u$ 的子节点,那么对这个节点分治的时候,方向是不确定的,有可能会向后回到 $u$,所以使用 $vis$ 对处理过的点进行标记,才能保证点分治的时间复杂度和正确性。

另外,当我们统计点 $u$ 的时候要算的是其它子节点到 $u$ 的距离,而 $u$ 到自身的距离为 $0$, 所以统计的时候要初始化 $dep[u] = 0$.

总的代码会像下面这样:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

struct Edge {
    int u, v, w, next;
};

const int MAXN = 10050,
    INF = 0x3f3f3f3f;

int n, k, cnt = 1, dcnt = 0, rt = 0, mins = INF, tot = 0, ans = 0;

int head[MAXN], son[MAXN], size[MAXN], d[MAXN], dep[MAXN];
bool vis[MAXN];
Edge e[MAXN << 1];

template<class T> void read(T &x) {
    T a = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        a = a * 10 + ch - '0';
        ch = getchar();
    }
    x = a * f;
}

void findRoot(int u, int p) {
    size[u] = 1, son[u] = 0;
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)] || v == p)
            continue;
        findRoot(v, u);
        size[u] += size[v];
        son[u] = max(son[u], size[v]);
    }
    son[u] = max(son[u], tot - son[u]);
    if (son[u] < mins)
        mins = son[u], rt = u;
}

void getDist(int u, int p) {
    d[dcnt++] = dep[u];
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)] || v == p)	
            continue;
        dep[v] = dep[u] + e[i].w;
        getDist(v, u);
    }
}

int getAns(int u) {
    int tmpans = 0;
    dcnt = 0;
    getDist(u, 0);
    sort(d, d + dcnt);
    
    int i = 0, j = dcnt - 1;
    while (i < j) {
        if (d[i] + d[j] <= k)
            tmpans += j - i, i++;
        else
            j--;
    }
    return tmpans;
}

void solve(int u) {
    dep[u] = 0;
    ans += getAns(u);
    vis[u] = 1;
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)])
            continue;
        ans -= getAns(v);
        tot = size[v];
        mins = INF;
        findRoot(v, 0);
        solve(rt);
    }
}

void add_edge(int u, int v, int w) {
    e[cnt] = (Edge){ u, v, w, head[u] };
    head[u] = cnt++;
}

int main() {
    while (1) {
        read(n), read(k);
        if (n == 0 && k == 0)	
            break;
        memset(head, 0, sizeof head);
        memset(vis, 0, sizeof vis);
        memset(size, 0, sizeof size);
        memset(son, 0, sizeof son);
        memset(dep, 0, sizeof dep);
        
        cnt = 1, dcnt = 0, rt = 0, mins = INF, tot = 0, ans = 0;
        
        for (int i = 0, u, v, w; i < n - 1; i++) {
            read(u), read(v), read(w);
            add_edge(u, v, w), add_edge(v, u, w);
        }
        
        tot = n;
        findRoot(1, 0);
        solve(rt);
        
        printf("%d\n", ans);
    }
    
    return 0;
}

例题

计算一棵树上距离为 $k$ 的点是否存在

luogu P3806 【模板】点分治1

给定一棵有 $n$ 个点的树,$m$ 次询问树上距离为 $k$ 的点对是否存在。($n \le 10000, m \le 100$)

这道题和树上距离(路径)相关,并且 $n$ 的规模达到了 $10000$, 所以我们考虑使用点分治解决(废话那不然怎么叫点分治模板).

找重心和分治的步骤大家都差不多,关键是统计答案的步骤。首先如何统计距离为 $k$ 的点对个数,显然我们还是对于每一个重心点 $i$ 求其他点到 $i$ 的距离 $d[i]$,然后在 $d$ 中找两个来自不同子树的点 $d[u], d[v]$ 且满足 $d[u] + d[v] = k$.

虽然这道题不像上面那道题一样求小于等于的情况,不适合用双指针移动的线性方法统计;但 $O(n^2)$ 的暴力枚举法及其优化都是很显然的——很自然地想到用二分查找就可以计数了,对吧。首先我们枚举每一个 $d[i]$,在 $d$ 数组中二分查找 $k - d[i]$ 是否存在(只需要找到一个即可),用 STL 的 lower_bound 函数就可以啦。

接下来是去掉不合法的答案。可以用上面那道题的的方法;但因为这道题是要判断存在性,我们可以直接用染色法:计算 $d$ 数组的时候同时记录其来自重心点 $rt$ 的哪一个子节点 $j$,然后在处理 $d[i]$、判断 $k - d[i]$ 的存在性时,忽略那些与 $d[i]$ 来自同一个子节点的 $k - d[i]$ 即可。

最后,因为这道题要询问多个 $k$,但是询问数量 $m$ 是很小的,因此我们每次统计答案的时候就把所有的询问都扫一遍,一个个判断是否符合并标记;若对于已经存在的询问 $k$ 则可以直接跳过。

参考代码:

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;

struct Edge {
    int u, v, next;
    ll w;
};

struct Dist {
    ll w;
    int par;
    
    bool operator < (const Dist &b) const {
        return w < b.w;
    }
};

const int MAXN = 10050,
    MAXM = 150,
    INF = 0x3f3f3f3f;
    
int n, m, cnt = 1, minv = INF, rt = 0, tot = 0, dcnt = 0;

bool vis[MAXN];
int head[MAXN], qry[MAXM], ans[MAXM], size[MAXN], son[MAXN];
Edge e[MAXN << 1];
Dist d[MAXN];

template<class T> void read(T &x) {
    T a = 0, f = 1;
    char ch = getchar();
    while (ch < '0' || ch > '9') {
        if (ch == '-')
            f = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        a = a * 10 + ch - '0';
        ch = getchar();
    }
    x = a * f;
}

void add_edge(int u, int v, ll w) {
    e[cnt] = (Edge){ u, v, head[u], w };
    head[u] = cnt++;
}

void findRoot(int u, int p) {
    size[u] = 1, son[u] = 0;
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)] || v == p)
            continue;
        findRoot(v, u);
        size[u] += size[v];
        son[u] = max(son[u], size[v]);
    }
    son[u] = max(son[u], tot - size[u]);
    if (son[u] < minv)
        minv = son[u], rt = u;		
}

void getDist(int u, int p, int par, ll dt) {
    d[dcnt++] = (Dist) { dt, par };
    for (int i = head[u], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)] || p == v)
            continue;
        getDist(v, u, par, dt + e[i].w);
    }
}

void solve(int cur) {
    dcnt = 0;
    for (int i = head[cur], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)])
            continue;
        getDist(v, cur, v, e[i].w);
    }
    d[dcnt++] = (Dist){ 0ll, 0 };
    sort(d, d + dcnt);
    
    for (int i = 0; i < m; i++) {
        if (ans[i])
            continue;
        int l = 0;
        while (l < dcnt && d[l].w + d[dcnt-1].w < qry[i])
            l++;
        while (l < dcnt && !ans[i]) {
            if (qry[i] - d[l].w < d[l].w)
                break;
            int pos = lower_bound(d, d + dcnt, (Dist){ qry[i] - d[l].w, 0 }) - d;
            while (pos < dcnt && d[pos].par == d[l].par)
                pos++;
            if (pos < dcnt && d[pos].w + d[l].w == qry[i])
                ans[i] = 1;
            l++;
        }
    }
}

void work(int x) {
    vis[x] = 1;
    solve(x);
    for (int i = head[x], v; i; i = e[i].next) {
        if (vis[(v = e[i].v)])
            continue;
        rt = 0, minv = INF, tot = size[v];
        findRoot(v, 0);
        work(rt);
    }
}

int main() {
    read(n), read(m);
    for (int i = 0, u, v; i < n-1; i++) {
        ll w;
        read(u), read(v), read(w);
        add_edge(u, v, w);
        add_edge(v, u, w);
    }
    
    for (int i = 0; i < m; i++)
        read(qry[i]);

    tot = n;
    findRoot(1, 0);
    work(rt);
    
    for (int i = 0; i < m; i++)
        printf(ans[i] ? "AYE\n" : "NAY\n");
    return 0;
}