跳转至

长链剖分优化 DP

什么是长链剖分?

长链剖分本质上就是区别于 重链剖分 的另外一种链剖分方式。

定义 重子节点 表示其子节点中子树深度最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。

定义 轻子节点 表示剩余的子结点。

从这个结点到重子节点的边为 重边

到其他轻子节点的边为 轻边

若干条首尾衔接的重边构成 重链

把落单的结点也当作重链,那么整棵树就被剖分成若干条重链。

优化 DP

这个优化来源于 树上启发式合并

一般情况下可以使用长链剖分来优化的 DP 会有一维状态为深度维。

我们可以考虑使用长链剖分优化树上 DP。

具体的,我们每个节点的状态直接继承其重儿子的节点状态,同时将轻儿子的 DP 状态暴力合并。

例题

CF1009F Dominant Indices

有一棵 \(n\) 个节点的树,对其每个节点求以其为根的子树中深度的众数的值。 数据范围: \(n \leq 10^6\)

我们设 \(f_{i,j}\) 表示在子树 \(i\) 内,和 \(i\) 距离为 \(j\) 的点数。

直接暴力转移时间复杂度为 \(O(n^2)\)

我们考虑每次转移我们直接继承重儿子的 DP 数组和答案,并且考虑在此基础上进行更新。

首先我们需要将重儿子的 DP 数组前面插入一个元素 \(1\),这代表着当前节点。

然后我们将所有轻儿子的 DP 数组暴力和当前节点的 DP 数组合并。

注意到因为轻儿子的 DP 数组长度为轻儿子所在重链长度,而所有重链长度和为 \(n\)

也就是说,我们直接暴力合并轻儿子的总时间复杂度为 \(O(n)\)

注意,一般情况下 DP 数组的内存分配为一条重链整体分配内存,链上不同的节点有不同的首位置指针。

DP 数组的长度我们可以根据子树最深节点算出。

实现
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define int ll
typedef long long ll;
typedef const int cint;
constexpr int MAXN = 1e6 + 10;
int n, buf[MAXN], *now = buf, *f[MAXN], *g[MAXN], dep[MAXN], son[MAXN], ans[MAXN];
vector<int> e[MAXN];
void dfs(cint u, cint fa)
{
    for (cint v : e[u])
        if (v ^ fa)
        {
            dfs(v, u);
            if (dep[v] > dep[son[u]]) son[u] = v;
        }
    dep[u] = dep[son[u]] + 1;
}
void dp(cint u, cint fa)
{
    f[u][0] = 1;
    if (son[u]) f[son[u]] = f[u] + 1, dp(son[u], u), ans[u] = ans[son[u]] + 1;
    for (cint v : e[u])
        if (v ^ fa && v ^ son[u])
        {
            f[v] = now, now += dep[v];
            dp(v, u);
            for (int i = 1; i <= dep[v]; i++)
            {
                f[u][i] += f[v][i - 1];
                if (f[u][i] > f[u][ans[u]] || (f[u][i] == f[u][ans[u]] && i < ans[u])) ans[u] = i;
            }
        }
    if (f[u][ans[u]] == 1) ans[u] = 0;
}
signed main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n;
    for (int u, v, i = 1; i < n; i++) cin >> u >> v, e[u].emplace_back(v), e[v].emplace_back(u);
    dfs(1, 0), f[1] = now, now += dep[1], dp(1, 0);
    for (int i = 1; i <= n; i++) cout << ans[i] << endl;
    return 0;
}
P5904 [POI 2014] HOT-Hotels 加强版

给出一棵有 \(n\) 个点的树,求有多少组无序点 \((i,j,k)\) 满足 \(i,j,k\) 两两之间的距离都相等。

数据范围: \(n \leq 10^5\)

\(f_{i,j}\) 为满足 \(x\)\(i\) 的子树中且 \(d(x,i)-j\)\(x\) 的个数,\(g_{i,j}\) 为满足 \(x,y\)\(i\) 的子树中且 \(d(\text{lca}(x,y),x)=d(\text{lca}(x,y),y)=d(\text{lca}(x,y),i)+j\) 的无序数对 \((x,y)\) 的个数。

有转移:

\[ \begin{aligned} ans & \leftarrow g_{i,0} \\ ans & \leftarrow \sum_{x,y \in \text{son}(i), x \neq y} f_{x,j-1} \times g_{y,j+1} \\ g_{i,j} & \leftarrow \sum_{x,y \in \text{son}(i), x \neq y} f_{x,j-1} \times f_{y,j-1} \\ g_{i,j} & \leftarrow \sum_{x \in \text{son}(i)} g_{x,j+1} \\ f_{i,j} & \leftarrow \sum_{x \in \text{son}(i)} f_{x,j-1} \end{aligned} \]

显然这可以利用前缀和,或者说是所有儿子「向 \(i\) 合并」,做到 \(O(n)\) 转移,总时间复杂度 \(O(n^2)\)

注意到这里的信息都是以深度为下标,那么可以利用长链剖分将复杂度降为均摊 \(O(1)\),总时间复杂度即可降为 \(O(n)\)

实现
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define int ll
typedef long long ll;
typedef const int cint;
constexpr int MAXN = 1e5 + 10;
int n, dep[MAXN], son[MAXN];
vector<int> e[MAXN];
ll *f[MAXN], *g[MAXN], p[MAXN << 2], *idx = p, ans;
void dfs(cint u, cint fa)
{
    for (cint v : e[u])
        if (v ^ fa)
        {
            dfs(v, u);
            if (dep[v] > dep[son[u]]) son[u] = v;
        }
    dep[u] = dep[son[u]] + 1;
}
void dp(cint u, cint fa)
{
    if (son[u]) f[son[u]] = f[u] + 1, g[son[u]] = g[u] - 1, dp(son[u], u);
    f[u][0] = 1, ans += g[u][0];
    for (cint v : e[u])
        if (v ^ fa && v ^ son[u])
        {
            f[v] = idx, idx += dep[v] << 1, g[v] = idx, idx += dep[v] << 1;
            dp(v, u);
            for (int i = 0; i < dep[v]; i++)
            {
                if (i) ans += f[u][i - 1] * g[v][i];
                ans += g[u][i + 1] * f[v][i];
            }
            for (int i = 0; i < dep[v]; i++)
            {
                g[u][i + 1] += f[u][i + 1] * f[v][i];
                if (i) g[u][i - 1] += g[v][i];
                f[u][i + 1] += f[v][i];
            }
        }
}
signed main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n;
    for (int u, v, i = 1; i < n; i++) cin >> u >> v, e[u].emplace_back(v), e[v].emplace_back(u);
    dfs(1, 0), f[1] = idx, idx += dep[1] << 1, g[1] = idx, idx += dep[1] << 1, dp(1, 0), cout << ans;
    return 0;
}