跳转至

历史和线段树

历史和线段树就是能在 \(O(n\log n)\) 中查询过去 \(q\) 个版本某个区间的和的总和。

形式化的说,有一个数组 \(a\) 和一个辅助数组 \(b\),每一次(广义)更新操作都会执行 \(a:[x,y] \rightarrow b:[x,y]\),查询 \(k\) 个版本后 \(b:[x,y]\) 的值(即 \(\sum\limits_{k}\sum\limits_{i = x}^y a_i\) 的值)。

矩阵法求解历史和

如何在 \(O(n\log n)\) 的时间内解决上述问题呢?

我们考虑使用线段树和矩阵。

LOJ193 【模板】线段树历史和

这是一道模板题。

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

  1. 区间加一个数;
  2. 查询区间的历史和;

历史和定义为数列 \(h_i\) 的区间和:初始 \(h_i=a_i\),在每次操作(修改或查询,具体可参考样例解释)完成后,对所有 \(h_i \leftarrow h_i+a_i\)

以 LOJ193 为例,我们要求最朴素的历史和。

我们可以用线段树维护矩阵,其中矩阵为:

\[ \begin{bmatrix} his \\ sum \\ len \end{bmatrix} \]

其中 \(his\) 为历史和,\(sum\) 为区间和,\(len\) 为区间长度。

其实就是用矩阵打包线段树上要维护的所有变量。

对于叶子节点,\(len = 1\) , \(sum = a_i\)\(his = 0\);对于非叶子节点,\(tag = \begin{bmatrix} 1 & 0 & 0\\ 0 & 1 & 0\\ 0 & 0 & 1\\ \end{bmatrix}\)

然后是对线段树进行区间矩阵乘操作:

节点的合并(例节点 \(a + b \rightarrow c\)):

\[ \begin{bmatrix} his_a \\ sum_a \\ len_a \end{bmatrix} + \begin{bmatrix} his_b \\ sum_b \\ len_b \end{bmatrix} = \begin{bmatrix} his_c \\ sum_c \\ len_c \end{bmatrix} \]

区间加 \(d\) 操作为:

\[ \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & d \\ 0 & 0 & 1 \\ \end{bmatrix} \times \begin{bmatrix} his \\ sum \\ len \end{bmatrix} = \begin{bmatrix} his \\ sum + d \times len \\ len \end{bmatrix} \]

区间历史和更新操作为:

\[ \begin{bmatrix} 1 & 1 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \\ \end{bmatrix} \times \begin{bmatrix} his \\ sum \\ len \end{bmatrix} = \begin{bmatrix} his + sum \\ sum \\ len \end{bmatrix} \]

我们将矩阵按线段树的方式下放到指定区间即可。

我们每次进行区间加 \(d\) 时,对全局进行历史和更新操作

最后查询区间矩阵的历史和,只需按线段树的方式求区间矩阵的 \(\sum his\) 即可。

这样你就成功完成了此题!

实现
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#include <bits/stdc++.h>
using namespace std;
#define int long long
int rd()
{
    int x = 0, w = 1;
    char ch = 0;
    while (ch < '0' || ch > '9')
    {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * w;
}
void wt(int x)
{
    static int sta[35];
    int f = 1;
    if (x < 0) f = -1, x *= f;
    int top = 0;
    do
    {
        sta[top++] = x % 10, x /= 10;
    } while (x);
    if (f == -1) putchar('-');
    while (top) putchar(sta[--top] + 48);
}
template <int N, int M, class T = long long>
struct matrix
{
    int m[N][M];
    matrix() { memset(m, 0, sizeof(m)); }
    void init()
    {
        for (int i = 0; i < N; i++) m[i][i] = 1;
    }
    friend bool operator!=(matrix<N, M> x, matrix<N, M> y)
    {
        for (int i = 0; i < N; i++)
            for (int j = 0; j < M; j++)
                if (x[i][j] != y[i][j]) return true;
        return false;
    }
    int *operator[](const int pos) { return m[pos]; }
    void print(string s)
    {
        cout << '\n';
        string t = "test for " + s + "  matrix:";
        cout << t << '\n';
        for (int i = 0; i < N; i++)
            for (int j = 0; j < M; j++) cout << m[i][j] << " \n"[j == M - 1];
        cout << '\n';
    }
};
template <int N, int M, int R, class T = long long>
matrix<N, R, T> operator*(matrix<N, M, T> a, matrix<M, R, T> b)
{
    matrix<N, R, T> c;
    for (int i = 0; i < N; i++)
        for (int j = 0; j < M; j++)
            for (int k = 0; k < R; k++) c[i][k] = c[i][k] + a[i][j] * b[j][k];
    return c;
}
template <int N, int M, class T = long long>
matrix<N, M, T> operator+(matrix<N, M, T> a, matrix<N, M, T> b)
{
    for (int i = 0; i < N; i++)
        for (int j = 0; j < M; j++) a[i][j] += b[i][j];
    return a;
}
const int N = 1e5 + 5;
int n, m, a[N];
namespace sgt
{
    matrix<3, 1> h[N << 2];
    matrix<3, 3> tag[N << 2];
#define ls (p << 1)
#define rs (ls | 1)
#define mid ((pl + pr) >> 1)
    void push_up(int p) { h[p] = h[ls] + h[rs]; }
    void addtag(int p, matrix<3, 3> c)
    {
        h[p] = c * h[p];
        tag[p] = c * tag[p];
    }
    void push_down(int p)
    {
        matrix<3, 3> c;
        c.init();
        if (tag[p] != c)
        {
            addtag(ls, tag[p]);
            addtag(rs, tag[p]);
            tag[p] = c;
        }
    }
    void build(int p, int pl, int pr)
    {
        matrix<3, 3> c;
        c.init();
        tag[p] = c;
        if (pl == pr)
        {
            h[p][0][0] = h[p][1][0] = a[pl];
            h[p][2][0] = 1;
            return;
        }
        build(ls, pl, mid);
        build(rs, mid + 1, pr);
        push_up(p);
    }
    void update(int p, int pl, int pr, int l, int r, matrix<3, 3> v)
    {
        if (l <= pl && pr <= r)
        {
            addtag(p, v);
            return;
        }
        push_down(p);
        if (l <= mid) update(ls, pl, mid, l, r, v);
        if (r > mid) update(rs, mid + 1, pr, l, r, v);
        push_up(p);
    }
    int query(int p, int pl, int pr, int l, int r)
    {
        if (l <= pl && pr <= r) return h[p][0][0];
        push_down(p);
        int ans = 0;
        if (l <= mid) ans += query(ls, pl, mid, l, r);
        if (r > mid) ans += query(rs, mid + 1, pr, l, r);
        return ans;
    }
} // namespace sgt
signed main()
{
    n = rd(), m = rd();
    for (int i = 1; i <= n; i++) a[i] = rd();
    sgt::build(1, 1, n);
    auto upd = [&]() -> void
    {
        int l = rd(), r = rd(), x = rd();
        matrix<3, 3> c;
        c.init();
        c[1][2] = x;
        sgt::update(1, 1, n, l, r, c);
    };
    auto qry = [&]() -> void
    {
        int l = rd(), r = rd();
        wt(sgt::query(1, 1, n, l, r));
        putchar('\n');
    };
    while (m--)
    {
        int opt = rd();
        switch (opt)
        {
        case 1:
            upd();
            break;
        case 2:
            qry();
            break;
        default:
            puts("Error");
            exit(0);
            break;
        }
        matrix<3, 3> v;
        v.init();
        v[0][1] = 1;
        sgt::update(1, 1, n, 1, n, v);
    }
    return 0;
}

通过记录:accept?

进一步优化

我们的矩阵乘法要维护两个 \(3 \times 3\) 矩阵相乘的结果,这带来的结果是常数来到了惊人的 \(27\),然而这是无法接受的!

这时聪明的奶龙就发现了,矩阵的好多地方是不变的

我们可以用下面的代码来探究到底哪些矩阵元素永远不会变:

探究随机矩阵乘所固定的元素
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <bits/stdc++.h>
using namespace std;
int rd()
{
    int x = 0, w = 1;
    char ch = 0;
    while (ch < '0' || ch > '9')
    {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * w;
}
void wt(int x)
{
    static int sta[35];
    int f = 1;
    if (x < 0) f = -1, x *= f;
    int top = 0;
    do
    {
        sta[top++] = x % 10, x /= 10;
    } while (x);
    if (f == -1) putchar('-');
    while (top) putchar(sta[--top] + 48);
}
template <int N, int M, class T = long long>
struct matrix
{
    int m[N][M];
    matrix() { memset(m, 0, sizeof(m)); }
    void init()
    {
        for (int i = 0; i < N; i++) m[i][i] = 1;
    }
    friend bool operator!=(matrix<N, M> x, matrix<N, M> y)
    {
        for (int i = 0; i < N; i++)
            for (int j = 0; j < M; j++)
                if (x[i][j] != y[i][j]) return true;
        return false;
    }
    int *operator[](const int pos) { return m[pos]; }
    void print(string s)
    {
        cout << '\n';
        string t = "test for " + s + "  matrix:";
        cout << t << '\n';
        for (int i = 0; i < N; i++)
            for (int j = 0; j < M; j++) cout << m[i][j] << " \n"[j == M - 1];
        cout << '\n';
    }
};
template <int N, int M, int R, class T = long long>
matrix<N, R, T> operator*(matrix<N, M, T> a, matrix<M, R, T> b)
{
    matrix<N, R, T> c;
    for (int i = 0; i < N; i++)
        for (int j = 0; j < M; j++)
            for (int k = 0; k < R; k++) c[i][k] = c[i][k] + a[i][j] * b[j][k];
    return c;
}
template <int N, int M, class T = long long>
matrix<N, M, T> operator+(matrix<N, M, T> a, matrix<N, M, T> b)
{
    for (int i = 0; i < N; i++)
        for (int j = 0; j < M; j++) a[i][j] += b[i][j];
    return a;
}
template <int N, class T = long long>
matrix<N, N, T> qpow(matrix<N, N, T> x, int k)
{
    matrix<N, N, T> re;
    re.init();
    while (k)
    {
        if (k & 1) re = re * x;
        x = x * x;
        k >>= 1;
    }
    return re;
}
matrix<3, 3> re, b;
signed main()
{
    re.init();
    while (1)
    {
        int c = rd();
        if (c == 0) return 0;
        if (c == 1)
        {
            b.init();
            int x = rd();
            b[1][2] = x;
            re = b * re;
            re.print("result:");
        }
        else if (c == 2)
        {
            b.init();
            b[0][1] = 1;
            re = b * re;
            re.print("result:");
        }
    }
    return 0;
}

我们会惊讶的发现,实际上矩阵中只有四个位置是在变化的:

\[ \begin{bmatrix} 1 & a & b \\ 0 & c & d \\ 0 & 0 & 1 \\ \end{bmatrix} \]

那么,我们可以通过手摸矩阵来达到 \(3 \sim 4\) 的复杂度常数!

实现
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#include <bits/stdc++.h>
using namespace std;
int rd()
{
    int x = 0, w = 1;
    char ch = 0;
    while (ch < '0' || ch > '9')
    {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * w;
}
void wt(int x)
{
    static int sta[35];
    int f = 1;
    if (x < 0) f = -1, x *= f;
    int top = 0;
    do
    {
        sta[top++] = x % 10, x /= 10;
    } while (x);
    if (f == -1) putchar('-');
    while (top) putchar(sta[--top] + 48);
}
struct tag
{
    int x[7];
    void init()
    {
        x[1] = x[4] = x[6] = 1;
        x[2] = x[3] = x[5] = 0;
    }
    int &operator[](const int pos) { return x[pos]; }
    friend tag operator*(tag &A, tag &B)
    {
        tag c;
        c.init();
        c[2] = A[2] + B[2];
        c[3] = B[3] + A[2] * B[5] + A[3];
        c[5] = B[5] + A[5];
        return c;
    }
    friend bool operator!=(tag A, tag B)
    {
        for (int i = 0; i < 7; i++)
            if (A[i] != B[i]) return true;
        return false;
    }
    void print(string s)
    {
        cout << "test for " << s << "     matrix\n";
        cout << x[1] << ' ' << x[2] << ' ' << x[3] << '\n';
        cout << 0 << ' ' << x[4] << ' ' << x[5] << '\n';
        cout << 0 << ' ' << 0 << ' ' << x[6] << '\n';
    }
};
struct vet
{
    int y[4];
    void init() { y[1] = y[2] = y[3] = 0; }
    int &operator[](const int pos) { return y[pos]; }
    friend vet operator+(vet a, vet b)
    {
        vet c;
        c.init();
        c[1] = a[1] + b[1];
        c[2] = a[2] + b[2];
        c[3] = a[3] + b[3];
        return c;
    }
    void print(string s)
    {
        cout << '\n';
        cout << "test for " << s << "     vector\n";
        cout << y[1] << '\n';
        cout << y[2] << '\n';
        cout << y[3] << '\n';
        cout << '\n';
    }
};
vet operator*(tag A, vet B)
{
    vet c;
    c.init();
    c[1] = B[1] + B[2] * A[2] + B[3] * A[3];
    c[2] = B[2] + A[5] * B[3];
    c[3] = B[3];
    return c;
}
const int N = 1e5 + 5;
int n, m, a[N];
namespace sgt
{
#define ls (p << 1)
#define rs (ls | 1)
#define mid ((pl + pr) >> 1)
    tag T[N << 2];
    vet t[N << 2];
    void push_up(int p) { t[p] = t[ls] + t[rs]; }
    void addtag(int p, tag x)
    {
        T[p] = x * T[p];
        t[p] = x * t[p];
    }
    void push_down(int p)
    {
        tag c;
        c.init();
        if (T[p] != c)
        {
            addtag(ls, T[p]);
            addtag(rs, T[p]);
            T[p].init();
        }
    }
    void build(int p, int pl, int pr)
    {
        T[p].init();
        if (pl == pr)
        {
            t[p][2] = t[p][1] = a[pl];
            t[p][3] = 1;
            return;
        }
        build(ls, pl, mid);
        build(rs, mid + 1, pr);
        push_up(p);
    }
    void update(int p, int pl, int pr, int l, int r, tag x)
    {
        if (l <= pl && pr <= r)
        {
            addtag(p, x);
            // t[p].print("upd");
            // T[p].print("upd");
            return;
        }
        push_down(p);
        if (l <= mid) update(ls, pl, mid, l, r, x);
        if (r > mid) update(rs, mid + 1, pr, l, r, x);
        push_up(p);
    }
    int query(int p, int pl, int pr, int l, int r)
    {
        if (l <= pl && pr <= r) return t[p][1];
        push_down(p);
        if (r <= mid)
            return query(ls, pl, mid, l, r);
        else if (l > mid)
            return query(rs, mid + 1, pr, l, r);
        else
            return query(ls, pl, mid, l, r) + query(rs, mid + 1, pr, l, r);
    }
} // namespace sgt
signed main()
{
    n = rd(), m = rd();
    for (int i = 1; i <= n; i++) a[i] = rd();
    sgt::build(1, 1, n);
    while (m--)
    {
        int opt = rd();
        if (opt == 1)
        {
            int l = rd(), r = rd(), x = rd();
            tag c;
            c.init();
            c[5] = x;
            sgt::update(1, 1, n, l, r, c);
        }
        else
        {
            int l = rd(), r = rd();
            wt(sgt::query(1, 1, n, l, r));
            putchar('\n');
        }
        tag c;
        c.init();
        c[2] = 1;
        sgt::update(1, 1, n, 1, n, c);
    }
    return 0;
}

通过记录:accept!

可以看到区别还是很大的!

推理法求历史和

待后人补充!

值得注意的事情

矩阵所维护的元素所执行的操作无非加减乘除,这是矩阵的作为线性代数的性质导致的。

也就是说,历史和线段树只能用来维护具有线性关系的元素!

所有对于一类历史和线段树问题,思路都是尝试转换成一系列线性关系的操作。

习题

codeforces 1824D: LuoTianyi and the Function

codeforces 1824D 题解

NOIP2022 比赛