跳转至

K-D 树

本文转载(或修改)自 OI-Wiki

k-D Tree(KDT , k-Dimension Tree) 是一种可以 高效处理 \(k\) 维空间信息 的数据结构。

在结点数 \(n\) 远大于 \(2^k\) 时,应用 k-D Tree 的时间效率很好。

在算法竞赛的题目中,一般有 \(k=2\)。在本页面分析时间复杂度时,将认为 \(k\) 是常数。

建树

k-D Tree 具有二叉搜索树的形态,二叉搜索树上的每个结点都对应 \(k\) 维空间内的一个点。其每个子树中的点都在一个 \(k\) 维的超长方体内,这个超长方体内的所有点也都在这个子树中。

假设我们已经知道了 \(k\) 维空间内的 \(n\) 个不同的点的坐标,要将其构建成一棵 k-D Tree,步骤如下:

  1. 若当前超长方体中只有一个点,返回这个点。

  2. 选择一个维度,将当前超长方体按照这个维度分成两个超长方体。

  3. 选择切割点:在选择的维度上选择一个点,这一维度上的值小于这个点的归入一个超长方体(左子树),其余的归入另一个超长方体(右子树)。

  4. 将选择的点作为这棵子树的根节点,递归对分出的两个超长方体构建左右子树,维护子树的信息。

为了方便理解,我们举一个 \(k=2\) 时的例子。

其构建出 k-D Tree 的形态可能是这样的:

其中树上每个结点上的坐标是选择的分割点的坐标,非叶子结点旁的 \(x\)\(y\) 是选择的切割维度。

这样的复杂度无法保证。对于 \(2,3\) 两步,我们提出两个优化:

  1. 轮流选择 \(k\) 个维度,以保证在任意连续 \(k\) 层里每个维度都被切割到。
  2. 每次在维度上选择切割点时选择该维度上的 中位数,这样可以保证每次分成的左右子树大小尽量相等。

可以发现,使用优化 \(2\) 后,构建出的 k-D Tree 的树高最多为 \(\log n+O(1)\)

现在,构建 k-D Tree 时间复杂度的瓶颈在于快速选出一个维度上的中位数,并将在该维度上的值小于该中位数的置于中位数的左边,其余置于右边。如果每次都使用 sort 函数对该维度进行排序,时间复杂度是 \(O(n\log^2 n)\) 的。事实上,单次找出 \(n\) 个元素中的中位数并将中位数置于排序后正确的位置的复杂度可以达到 \(O(n)\)

我们来回顾一下快速排序的思想。每次我们选出一个数,将小于该数的置于该数的左边,大于该数的置于该数的右边,保证该数在排好序后正确的位置上,然后递归排序左侧和右侧的值。这样的期望复杂度是 \(O(n\log n)\) 的。但是由于 k-D Tree 只要求要中位数在排序后正确的位置上,所以我们只需要递归排序包含中位数的 一侧。可以证明,这样的期望复杂度是 \(O(n)\) 的。在 algorithm 库中,有一个实现相同功能的函数 nth_element(),要找到 s[l]s[r] 之间的值按照排序规则 cmp 排序后在 s[mid] 位置上的值,并保证 s[mid] 左边的值小于 s[mid],右边的值大于 s[mid],只需写 nth_element(s+l,s+mid,s+r+1,cmp)

借助这种思想,构建 k-D Tree 时间复杂度是 \(O(n\log n)\) 的。

高维空间上的操作

在查询高维矩形区域内的所有点的一些信息时,记录每个结点子树内每一维度上的坐标的最大值和最小值。如果当前子树对应的矩形与所求矩形没有交点,则不继续搜索其子树;如果当前子树对应的矩形完全包含在所求矩形内,返回当前子树内所有点的权值和;否则,判断当前点是否在所求矩形内,更新答案并递归在左右子树中查找答案。

实现
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
int query(int p) {
  if (!p) return 0;
  bool flag{false};
  for (int k : {0, 1}) flag |= (!(l.x[k] <= t[p].L[k] && t[p].R[k] <= h.x[k]));
  if (!flag) return t[p].sum;
  for (int k : {0, 1})
    if (t[p].R[k] < l.x[k] || h.x[k] < t[p].L[k]) return 0;
  int ans{0};
  flag = false;
  for (int k : {0, 1}) flag |= (!(l.x[k] <= t[p].x[k] && t[p].x[k] <= h.x[k]));
  if (!flag) ans = t[p].v;
  return ans += query(t[p].l) + query(t[p].r);
}

复杂度分析

先考虑二维的,在查询矩形 \(R\) 时,我们将 k-D Tree 上的结点分为三类:

  1. \(R\) 无交。
  2. 完全被 \(R\) 包含。
  3. 部分被 \(R\) 包含。

显然单次查询的复杂度是第 3 类点的个数。注意到第三类点的矩形要么完全包含 \(R\),要么互不包含,而前者显然只有 \(O(h)=O(\log n)\) 个,现在我们来分析后者的个数。

首先,我们不妨令矩形的所有边偏移 \(\epsilon\),使得查询矩形不穿过已经有的任何点。这样显然是不影响矩形的查询所涵盖的点集的。

注意到互不包含的第 3 类点所对应的矩形,一定有 \(R\) 的一条边穿过之。所以我们只需要计算 \(R\) 的每条边穿过的矩形个数,即任意一条线段最多经过多少个点对应的矩形。

考虑对于某一个结点 \(u\),它有四个孙子,且它到每一个孙子都在两个维度上各进行了一次划分。经过观察可以发现,按照这种方法将一个矩形划分成四个子矩形,一条与坐标轴平行的线段最多经过两个区域,即从 \(u\) 出发的查询,最多向下进入两个孙子仍有第 3 类点(如果线段刚好与分割边界重合则不一定,但是我们偏移查询矩形边界的操作使得这种情况不存在)。

而因为建树的时候,每个点是其整个子树在当前划分维度上的中位数,所以子树大小必定减半。于是,设 \(u\) 的子树大小为 \(n\),我们能写出如下递归式:

\[ T(n)=2T(n/4)+O(1) \]

由主定理得 \(T(n)=O(\sqrt{n})\)

将递归式推广到 \(k\) 维,即 \(T(n)=2^{k-1}T(n/2^k)+O(1)\),于是 \(T(n)=O(n^{1-\frac1k})\)(将 \(k\) 视为常数)。

插入/删除

如果维护的这个 \(k\) 维点集是可变的,即可能会插入或删除一些点,此时 k-D Tree 的平衡性无法保证。由于 k-D Tree 的构造,不能支持旋转,类似与 FHQ Treap 的随机优先级也不能保证其复杂度。对此,有两种比较常见的维护方法。

Note

很多选手会使用替罪羊树结构来维护。但是注意到在刚才的复杂度分析中,要求儿子的子树大小严格减半,即树高必须为严格的 \(\log n+O(1)\),而替罪羊树只满足树高 \(O(\log n)\),故查询复杂度无法保证。

根号重构

插入的时候,先存下来要插入的点,每 \(B\) 次插入进行一次重构。

删除打个标记即可。如果要求较为严格,可以维护树内有多少个被删除了,达到 \(B\) 则重构。

修改复杂度均摊 \(O(n\log n/B)\),查询 \(O(B+n^{1-\frac1k})\),若二者数量同阶则 \(B=O(\sqrt{n\log n})\) 最优(修改 \(O(\sqrt{n\log n})\),查询 \(O(\sqrt{n\log n}+n^{1-\frac1k})\))。

二进制分组

考虑维护若干棵 \(2\) 的自然数次幂的 k-D Tree,满足这些树的大小之和为 \(n\)

插入的时候,新增一棵大小为 \(1\) 的 k-D Tree,然后不断将相同大小的树合并(直接拍扁重构)。实现的时候可以只重构一次。

容易发现需要合并的树的大小一定从 \(2^0\) 开始且指数连续。复杂度类似二进制加法,是均摊 \(O(n\log^2 n)\) 的,因为重构本身带 \(\log\)

查询的时候,直接分别在每棵树上查询,复杂度为 \(O\left(\sum_{i\geq0} (\frac n{2^i})^{1-\frac1k}\right)=O(n^{1-\frac1k})\)

例题

洛谷 P4148 简单题

在一个初始值全为 \(0\)\(n\times n\) 的二维矩阵上,进行 \(q\) 次操作,每次操作为以下两种之一:

  1. 1 x y A:将坐标 \((x,y)\) 上的数加上 \(A\)
  2. 2 x1 y1 x2 y2:输出以 \((x1,y1)\) 为左下角,\((x2,y2)\) 为右上角的矩形内(包括矩形边界)的数字和。

强制在线。内存限制 20M。保证答案及所有过程量在 int 范围内。

\(1\le n\le 500000 , 1\le q\le 200000\)

20M 的空间卡掉了所有树套树,强制在线卡掉了 CDQ 分治,只能使用 k-D Tree。

以下是二进制分组的参考代码。

实现
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define mid ((l + r) >> 1)
typedef long long ll;
const int MAXN = 5e5 + 10;
inline void chkmx(int &a, int b) { a = a < b ? b : a; }
inline void chkmn(int &a, int b) { a = a > b ? b : a; }
inline void chkmx(ll &a, ll b) { a = a < b ? b : a; }
inline void chkmn(ll &a, ll b) { a = a > b ? b : a; }
int lstans, tr[MAXN][2], sum[MAXN], val[MAXN], lst[MAXN][2], rst[MAXN][2],
    b[MAXN], rt[MAXN], cnt, ls[MAXN], rs[MAXN], n, pl[2], pr[2];
void upd(int x)
{
    sum[x] = sum[ls[x]] + sum[rs[x]] + val[x];
    for (int k : {0, 1})
    {
        lst[x][k] = rst[x][k] = tr[x][k];
        if (ls[x])
        {
            chkmn(lst[x][k], lst[ls[x]][k]);
            chkmx(rst[x][k], rst[ls[x]][k]);
        }
        if (rs[x])
        {
            chkmn(lst[x][k], lst[rs[x]][k]);
            chkmx(rst[x][k], rst[rs[x]][k]);
        }
    }
}
int build(int l, int r, int k)
{
    nth_element(b + l, b + mid, b + r + 1,
                [&](int x, int y)
                { return tr[x][k] < tr[y][k]; });
    int x = b[mid];
    if (l < mid) ls[x] = build(l, mid - 1, k ^ 1);
    if (mid < r) rs[x] = build(mid + 1, r, k ^ 1);
    upd(x);
    return x;
}
void append(int &x)
{
    if (!x) return;
    b[++cnt] = x;
    append(ls[x]);
    append(rs[x]);
    x = 0;
}
bool covered(int x)
{
    for (int k : {0, 1})
        if (!(pl[k] <= lst[x][k] && rst[x][k] <= pr[k])) return false;
    return true;
}
bool check(int x)
{
    for (int k : {0, 1})
        if (pr[k] < lst[x][k] || rst[x][k] < pl[k]) return false;
    return true;
}
bool valid(int x)
{
    for (int k : {0, 1})
        if (!(pl[k] <= tr[x][k] && tr[x][k] <= pr[k])) return false;
    return true;
}
int qry(int x)
{
    if (!x || !check(x)) return 0;
    if (covered(x)) return sum[x];
    int res = 0;
    if (valid(x)) res += val[x];
    return res + qry(ls[x]) + qry(rs[x]);
}
int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, op, x, y, a;
    cin >> n;
    n = 0;
    while (1)
    {
        cin >> op;
        switch (op)
        {
        case 1:
            cin >> x >> y >> a;
            x ^= lstans, y ^= lstans, a ^= lstans;
            val[++n] = a;
            tr[n][0] = x, tr[n][1] = y;
            b[cnt = 1] = n;
            for (int sz = 0;; sz++)
                if (!rt[sz])
                {
                    rt[sz] = build(1, cnt, 0);
                    break;
                }
                else
                    append(rt[sz]);
            break;
        case 2:
            cin >> pl[0] >> pl[1] >> pr[0] >> pr[1];
            pl[0] ^= lstans, pl[1] ^= lstans, pr[0] ^= lstans, pr[1] ^= lstans;
            lstans = 0;
            for (int i = 0; i < 25; i++) lstans += qry(rt[i]);
            cout << lstans << endl;
            break;
        case 3:
            return 0;
        }
    }
    return 0;
}

邻域查询

Warning

使用 k-D Tree 单次查询最近点的时间复杂度最坏还是 \(O(n)\) 的,但不失为一种优秀的骗分算法,使用时请注意。在这里对邻域查询的讲解仅限于加强对 k-D Tree 结构的认识。

luogu P1429 平面最近点对(加强版)

给定平面上的 \(n\) 个点 \((x_i,y_i)\),找出平面上最近两个点对之间的欧几里得距离。

\(2\le n\le 200000 , 0\le x_i,y_i\le 10^9\)

首先建出关于这 \(n\) 个点的 2-D Tree。

枚举每个结点,对于每个结点找到不等于该结点且距离最小的点,即可求出答案。每次暴力遍历 2-D Tree 上的每个结点的时间复杂度是 \(O(n)\) 的,需要剪枝。我们可以维护一个子树中的所有结点在每一维上的坐标的最小值和最大值。假设当前已经找到的最近点对的距离是 \(ans\),如果查询点到子树内所有点都包含在内的长方形的 最近 距离大于等于 \(ans\),则在这个子树内一定没有答案,搜索时不进入这个子树。

此外,还可以使用一种启发式搜索的方法,即若一个结点的两个子树都有可能包含答案,先在与查询点距离最近的一个子树中搜索答案。可以认为,查询点到子树对应的长方形的最近距离就是此题的估价函数

实现
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
using namespace std;
constexpr int MAXN = 200010;
int n, d[MAXN], lc[MAXN], rc[MAXN];
double ans = 2e18;
struct node
{
    double x, y;
} s[MAXN];
double L[MAXN], R[MAXN], D[MAXN], U[MAXN];
double dist(int a, int b)
{
    return (s[a].x - s[b].x) * (s[a].x - s[b].x) +
           (s[a].y - s[b].y) * (s[a].y - s[b].y);
}
bool cmp1(node a, node b) { return a.x < b.x; }
bool cmp2(node a, node b) { return a.y < b.y; }
void maintain(int x)
{
    L[x] = R[x] = s[x].x;
    D[x] = U[x] = s[x].y;
    if (lc[x])
        L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
        D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
    if (rc[x])
        L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
        D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}
int build(int l, int r)
{
    if (l > r) return 0;
    if (l == r)
    {
        maintain(l);
        return l;
    }
    int mid = (l + r) >> 1;
    double avx = 0, avy = 0, vax = 0, vay = 0; // average variance
    for (int i = l; i <= r; i++) avx += s[i].x, avy += s[i].y;
    avx /= (double)(r - l + 1);
    avy /= (double)(r - l + 1);
    for (int i = l; i <= r; i++)
        vax += (s[i].x - avx) * (s[i].x - avx),
            vay += (s[i].y - avy) * (s[i].y - avy);
    if (vax >= vay)
        d[mid] = 1, nth_element(s + l, s + mid, s + r + 1, cmp1);
    else
        d[mid] = 2, nth_element(s + l, s + mid, s + r + 1, cmp2);
    lc[mid] = build(l, mid - 1), rc[mid] = build(mid + 1, r);
    maintain(mid);
    return mid;
}
double f(int a, int b)
{
    double ret = 0;
    if (L[b] > s[a].x) ret += (L[b] - s[a].x) * (L[b] - s[a].x);
    if (R[b] < s[a].x) ret += (s[a].x - R[b]) * (s[a].x - R[b]);
    if (D[b] > s[a].y) ret += (D[b] - s[a].y) * (D[b] - s[a].y);
    if (U[b] < s[a].y) ret += (s[a].y - U[b]) * (s[a].y - U[b]);
    return ret;
}
void query(int l, int r, int x)
{
    if (l > r) return;
    int mid = (l + r) >> 1;
    if (mid != x) ans = min(ans, dist(x, mid));
    if (l == r) return;
    double distl = f(x, lc[mid]), distr = f(x, rc[mid]);
    if (distl < ans && distr < ans)
    {
        if (distl < distr)
        {
            query(l, mid - 1, x);
            if (distr < ans) query(mid + 1, r, x);
        }
        else
        {
            query(mid + 1, r, x);
            if (distl < ans) query(l, mid - 1, x);
        }
    }
    else
    {
        if (distl < ans) query(l, mid - 1, x);
        if (distr < ans) query(mid + 1, r, x);
    }
}
int main()
{
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> s[i].x >> s[i].y;
    build(1, n);
    for (int i = 1; i <= n; i++) query(1, n, i);
    cout << fixed << setprecision(4) << sqrt(ans) << '\n';
    return 0;
}
「CQOI2016」K 远点对

给定平面上的 \(n\) 个点 \((x_i,y_i)\),求欧几里得距离下的第 \(k\) 远无序点对之间的距离。

\(n\le 100000 , 1\le k\le 100 , 0\le x_i,y_i<2^{31}\)

和上一道例题类似,从最近点对变成了 \(k\) 远点对,估价函数改成了查询点到子树对应的长方形区域的最远距离。用一个小根堆来维护当前找到的前 \(k\) 远点对之间的距离,如果当前找到的点对距离大于堆顶,则弹出堆顶并插入这个距离,同样的,使用堆顶的距离来剪枝。

由于题目中强调的是无序点对,即交换前后两点的顺序后仍是相同的点对,则每个有序点对会被计算两次,那么读入的 \(k\) 要乘以 \(2\)

实现
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <algorithm>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
constexpr int MAXN = 100010;
long long n, k;
priority_queue<long long, vector<long long>, greater<long long>> q;
struct node
{
    long long x, y;
} s[MAXN];
bool cmp1(node a, node b) { return a.x < b.x; }
bool cmp2(node a, node b) { return a.y < b.y; }
long long lc[MAXN], rc[MAXN], L[MAXN], R[MAXN], D[MAXN], U[MAXN];
void maintain(int x)
{
    L[x] = R[x] = s[x].x;
    D[x] = U[x] = s[x].y;
    if (lc[x])
        L[x] = min(L[x], L[lc[x]]), R[x] = max(R[x], R[lc[x]]),
        D[x] = min(D[x], D[lc[x]]), U[x] = max(U[x], U[lc[x]]);
    if (rc[x])
        L[x] = min(L[x], L[rc[x]]), R[x] = max(R[x], R[rc[x]]),
        D[x] = min(D[x], D[rc[x]]), U[x] = max(U[x], U[rc[x]]);
}
int build(int l, int r)
{
    if (l > r) return 0;
    int mid = (l + r) >> 1;
    double av1 = 0, av2 = 0, va1 = 0, va2 = 0; // average variance
    for (int i = l; i <= r; i++) av1 += s[i].x, av2 += s[i].y;
    av1 /= (r - l + 1);
    av2 /= (r - l + 1);
    for (int i = l; i <= r; i++)
        va1 += (av1 - s[i].x) * (av1 - s[i].x),
            va2 += (av2 - s[i].y) * (av2 - s[i].y);
    if (va1 > va2)
        nth_element(s + l, s + mid, s + r + 1, cmp1);
    else
        nth_element(s + l, s + mid, s + r + 1, cmp2);
    lc[mid] = build(l, mid - 1);
    rc[mid] = build(mid + 1, r);
    maintain(mid);
    return mid;
}
long long sq(long long x) { return x * x; }
long long dist(int a, int b)
{
    return max(sq(s[a].x - L[b]), sq(s[a].x - R[b])) +
           max(sq(s[a].y - D[b]), sq(s[a].y - U[b]));
}
void query(int l, int r, int x)
{
    if (l > r) return;
    int mid = (l + r) >> 1;
    long long t = sq(s[mid].x - s[x].x) + sq(s[mid].y - s[x].y);
    if (t > q.top()) q.pop(), q.push(t);
    long long distl = dist(x, lc[mid]), distr = dist(x, rc[mid]);
    if (distl > q.top() && distr > q.top())
    {
        if (distl > distr)
        {
            query(l, mid - 1, x);
            if (distr > q.top()) query(mid + 1, r, x);
        }
        else
        {
            query(mid + 1, r, x);
            if (distl > q.top()) query(l, mid - 1, x);
        }
    }
    else
    {
        if (distl > q.top()) query(l, mid - 1, x);
        if (distr > q.top()) query(mid + 1, r, x);
    }
}
int main()
{
    cin >> n >> k;
    k *= 2;
    for (int i = 1; i <= k; i++) q.push(0);
    for (int i = 1; i <= n; i++) cin >> s[i].x >> s[i].y;
    build(1, n);
    for (int i = 1; i <= n; i++) query(1, n, i);
    cout << q.top() << endl;
    return 0;
}

习题

「SDOI2010」捉迷藏

「Violet」天使玩偶/SJY 摆棋子

「国家集训队」JZPFAR

「BOI2007」Mokia 摩基亚

luogu P4475 巧克力王国

「CH 弱省胡策 R2」TATT