线段树优化 DP

一些动态规划问题的转移可以转化为动态的区间修改与查询,这时便可以利用线段树来加速了

下面是一些例子

CF833B The Bakery

将一个长度为 \(n\) 的序列分为 \(k\) 段,使得总价值最大。

一段区间的价值表示为区间内不同数字的个数。

数据范围: \(n \leq 35000, k \leq 50\)

我们记 \(w(i, j)\)\([l, r]\) 内不同数字的个数,那么设 \(f_{i,k}\) 为将前 \(i\) 个数分成 \(j\) 段所能得到的最大总价值

那么容易得出转移方程

\[f_{i,k + 1} = \max_{0 \le j < i}\{f_{j,k} + w(j + 1, i)\}\]

我们发现 \(w(i,j)\) 很难在短时间内重复计算,那么我们换个思路,每个数会对哪些地方的 \(w(i,j)\) 值产生贡献?

记这个序列为 \(a_1,a_2,\ldots,a_n\),记 \(lst_i\)\(1 \sim i-1\)\(a_i\) 出现的最右侧位置的下标加一,那么 \(a_i\) 会对 \(j \in [lst_i, i], t \in [i, n]\)\(w(j, t)\) 产生 \(1\) 的贡献

然后上面的求 \(\max\) 就成了 \(f_{0, \ldots ,i - 1}\) 里面取最大值

这是一个区间加与区间查最大值,显然可以用线段树优化

那么直接暴力枚举 \(k\),每次先清空,然后用线段树维护 \(f_{k - 1, j - 1} + w(j, i)\) 的值

考虑让 \(i\)\(1\) 循环到 \(n\),先让 \([lst_i, i]\) 加一,然后在 \([0,i]\) 里面查询最大值记到 \(f_{k,i}\) 里面就行

时间复杂度 \(O(kn\log n)\),代码如下:

实现
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
#define mid ((l + r) >> 1)
#define ls ((x) << 1)
#define rs ((x) << 1 | 1)
const int MAXN = 35010, inf = 0x7fffffff;
int dp[60][MAXN], lst[MAXN], pos[MAXN], tr[MAXN << 2], lazy[MAXN << 2];
void addtag(int x, int v)
{
    tr[x] += v;
    lazy[x] += v;
}
void pu(int x) { tr[x] = max(tr[ls], tr[rs]); }
void pd(int x)
{
    if (!lazy[x])
        return;
    addtag(ls, lazy[x]);
    addtag(rs, lazy[x]);
    lazy[x] = 0;
}
void build(int x, int l, int r, int now)
{
    tr[x] = lazy[x] = 0;
    if (l == r)
    {
        tr[x] = dp[now][l - 1];
        return;
    }
    build(ls, l, mid, now);
    build(rs, mid + 1, r, now);
    pu(x);
}
void update(int x, int l, int r, int pl, int pr, int val)
{
    if (pl <= l && r <= pr)
    {
        addtag(x, val);
        return;
    }
    pd(x);
    if (pl <= mid)
        update(ls, l, mid, pl, pr, val);
    if (mid < pr)
        update(rs, mid + 1, r, pl, pr, val);
    pu(x);
}
int query(int x, int l, int r, int pl, int pr)
{
    if (pl <= l && r <= pr)
        return tr[x];
    pd(x);
    int res = -inf;
    if (pl <= mid)
        res = max(res, query(ls, l, mid, pl, pr));
    if (mid < pr)
        res = max(res, query(rs, mid + 1, r, pl, pr));
    return res;
}
int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int n, k, t;
    cin >> n >> k;
    for (int i = 1; i <= n; i++)
    {
        cin >> t;
        lst[i] = pos[t] + 1;
        pos[t] = i;
    }
    for (int i = 1; i <= k; i++)
    {
        build(1, 0, n, i - 1);
        for (int j = 1; j <= n; j++)
        {
            update(1, 0, n, lst[j], j, 1);
            dp[i][j] = query(1, 0, n, 0, j);
        }
    }
    cout << dp[k][n];
    return 0;
}

再来一题

P2605 [ZJOI2010] 基站选址

\(N\) 个村庄坐落在一条直线上,第 \(i(i>1)\) 个村庄距离第 \(1\) 个村庄的距离为 \(D_i\)。需要在这些村庄中建立不超过 \(K\) 个通讯基站,在第 \(i\) 个村庄建立基站的费用为 \(C_i\)。如果在距离第 \(i\) 个村庄不超过 \(S_i\) 的范围内建立了一个通讯基站,那么就村庄被基站覆盖了。如果第 \(i\) 个村庄没有被覆盖,则需要向他们补偿,费用为 \(W_i\)。现在的问题是,选择基站的位置,使得总费用最小。

\(100\%\) 的数据中,\(K\leq N\)\(K\leq 100\)\(N\leq 2\times 10^4\)\(D_i \leq 10^9\)\(C_i\leq 10^4\)\(S_i \leq10^9\)\(W_i \leq 10^4\)

我们设 \(f_{i,j}\) 表示第 \(j\) 个基站建立在 \(i\) 位置,考虑 \([1, i]\) 位置产生的最小总费用。

容易写出转移方程

\[f_{i,k} = \min_{j = 1}^{k-1}\{f_{j,k - 1} + cost(j,i)\} + c_i\]

其中 \(cost(j,i)\)\([j + 1,i - 1]\) 区间中因为信号覆盖不到产生的总赔偿费用,\(c_i\) 为在 \(i\) 处建基站的费用。

看到式子里有 \(\min\),并且显然 \(1\le k\le i - 1\),很容易想到用线段树来处理,但是又发现这个 \(cost(k,i)\) 不太好维护。下面重点说说这个怎么处理。

我们不直接计算 \(cost(j,i)\),而是考虑这样一个问题,对于一个村庄 \(x\),什么时候它会产生赔偿 \(w_x\)

我们先预处理出最左边和最右边的信号能覆盖到 \(x\) 的村庄编号 \(L_x\)\(R_x\),那么当 \(k<L_x\)\(i>R_x\) 时,会产生赔偿 \(w_x\)。放在线段树上处理就是:当 \(i = R_x+1\) 时,我们就让线段树上 \([1, L_x-1]\) 的位置加上 \(w_x\),表示如果上一个基站建在 \([1,L_x-1]\) 的位置的话,那么 \(w_x\) 就会产生贡献了。 那么如何方便地找到刚好在 \(i - 1\) 位置的那些 \(R_x\) 呢?用一个vector来存就可以了。

这里运用一个小技巧,让 \(n \leftarrow n + 1\)\(k \leftarrow k + 1\),然后在新的 \(n\) 位置加一个虚拟村庄,最后答案就是 \(ans = f[n]\),因为是让第 \(k + 1\) 个虚拟基站建在这个虚拟村庄这里,用于收集答案。

实现
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define endl '\n'
#define mid ((l + r) >> 1)
#define ls ((x) << 1)
#define rs ((x) << 1 | 1)
inline void chkmn(ll &a, ll b) { a = a > b ? b : a; }
const int MAXN = 2e4 + 10;
const ll inf = 1e18;
vector<int> dr[MAXN];
ll tr[MAXN << 2], d[MAXN], c[MAXN], s[MAXN], L[MAXN], R[MAXN], w[MAXN], ans, lazy[MAXN << 2], f[MAXN];
int n, k;
void addtag(int x, ll v)
{
    tr[x] += v;
    lazy[x] += v;
}
void pu(int x) { tr[x] = min(tr[ls], tr[rs]); }
void pd(int x)
{
    if (lazy[x])
    {
        addtag(ls, lazy[x]);
        addtag(rs, lazy[x]);
        lazy[x] = 0;
    }
}
void build(int x, int l, int r)
{
    lazy[x] = 0;
    if (l >= r)
    {
        tr[x] = f[l];
        return;
    }
    build(ls, l, mid);
    build(rs, mid + 1, r);
    pu(x);
}
void update(int x, int l, int r, int pl, int pr, ll v)
{
    if (pl > pr)
        return;
    if (pl <= l && r <= pr)
        return addtag(x, v);
    pd(x);
    if (pl <= mid)
        update(ls, l, mid, pl, pr, v);
    if (mid < pr)
        update(rs, mid + 1, r, pl, pr, v);
    pu(x);
}
ll query(int x, int l, int r, int pl, int pr)
{
    if (pl > pr)
        return inf;
    if (pl <= l && r <= pr)
        return tr[x];
    ll res = inf;
    if (pl <= mid)
        chkmn(res, query(ls, l, mid, pl, pr));
    if (mid < pr)
        chkmn(res, query(rs, mid + 1, r, pl, pr));
    return res;
}
int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    cin >> n >> k;
    for (int i = 2; i <= n; i++)
        cin >> d[i];
    for (int i = 1; i <= n; i++)
        cin >> c[i];
    for (int i = 1; i <= n; i++)
        cin >> s[i];
    for (int i = 1; i <= n; i++)
        cin >> w[i];
    d[++n] = inf;
    ll cur = 0;
    for (int i = 1; i <= n; i++)
    {
        L[i] = lower_bound(d + 1, d + 1 + n, d[i] - s[i]) - d;
        R[i] = upper_bound(d + 1, d + 1 + n, d[i] + s[i]) - d - 1;
        dr[R[i]].push_back(i);
    }
    for (int i = 1; i <= n; i++)
    {
        f[i] = cur + c[i];
        for (int x : dr[i])
            cur += w[x];
    }
    ans = f[n];
    for (int i = 2; i <= k + 1; i++)
    {
        build(1, 1, n);
        for (int j = 1; j <= n; j++)
        {
            f[j] = query(1, 1, n, 1, j - 1) + c[j];
            for (int x : dr[j])
                update(1, 1, n, 1, L[x] - 1, w[x]);
        }
        chkmn(ans, f[n]);
    }
    cout << ans;
    return 0;
}

小练习

P9871 [NOIP2023] 天天爱打卡

小 T 同学非常热衷于跑步。为了让跑步更加有趣,他决定制作一款叫做《天天爱打卡》的软件,使得用户每天都可以进行跑步打卡。

开发完成后,小 T 同学计划进行试运行,他找了大 Y 同学来帮忙。试运行共 \(n\) 天,编号为从 \(1\)\(n\)

对大 Y 同学来说,如果某天他选择跑步打卡,那么他的能量值会减少 \(d\)。初始时,他的能量值是 \(0\),并且试运行期间他的能量值可以是负数

而且大 Y 不会连续跑步打卡超过 \(k\) 天;即不能存在 \(1\le x\le n-k\),使得他在第 \(x\) 到第 \(x+k\) 天均进行了跑步打卡。

小 T 同学在软件中设计了 \(m\) 个挑战,第 \(i\)\(1\le i \le m\))个挑战可以用三个正整数 \((x_i,y_i,v_i)\) 描述,表示如果在第 \(x_i\) 天时,用户已经连续跑步打卡至少 \(y_i\) 天(即第 \(x_i-y_i+1\) 到第 \(x_i\) 天均完成了跑步打卡),那么小 T 同学就会请用户吃饭,从而使用户的能量值提高 \(v_i\)

现在大 Y 想知道,在软件试运行的 \(n\) 天结束后,他的能量值最高可以达到多少?

本题有 \(t\) 组测试数据。

\(l_i=x_i-y_i+1\)\(r_i=x_i\) ​;

对于所有测试数据,保证: \(1\le t\le 10\)\(1\le k\le n\le 10^9\)\(1\le m\le 10^5\)\(1\le l_i\le r_i\le n\)\(1\le d,v_i\le 10^9\)

分析

考虑设dp状态 \(f_i\) 为考虑前 \(i\) 天且强制第 \(i\) 天跑步的最大值,同样设 \(g_i = \max_{j = 1}^{i} f_j\)。 易得一个转移方程:

\[f_i=\max_{j = i - k}^{i - 1}\{g_{j - 1}-(i - j)\cdot d+\sum_{[l_p,r_p]\subseteq(j,i]}v_p\}\]

也就是强制第 \(i\) 天跑步,且第 \(j\) 天不跑。

这个dp方程很经典,考虑使用线段树优化。 先考虑如何做到 \(O(n\log n)\)。 我们考虑线段树下标是 \(j\) 的决策点,存的东西是决策对应的值(记为 \(val_j\)),考虑当 \(i\) 变化时不同决策点的代价如何变化。

  1. 首先每一个决策点 \(j\)\(i\) 都比到 \(i - 1\) 要多跑一天的步,所以 \(val_j\) 要减去 \(d\)
  2. 然后这时候满足 \(j<l_p\)\(i\geq r_p\) 的挑战 \(p\) 变得可以选择(因为随着 \(i\) 的不断右移,只要满足上述条件这些挑战就始终能选择),也就是对于所有的 \(p\) 满足 \(j<l_p\)\(i\geq r_p\)\(val_j\) 要加上 \(v_p\)
  3. 然后就得到了 \(val_j\) 的变化方式,而 \(val_j\) 对比转移方程恰好在 \(j\) 作为决策点时取到,所以可以据此进行优化。

那么这就是一个 \(O(n\log n)\) 的做法。

然后考虑如何优化到 \(O((n + m)\log(n + m))\)。容易发现的是,我们在转移过程中,只有 \(l_p\)\(r_p\) 是有用的,所以我们考虑离散化,把 \(l_p\)\(r_p\) 丢到离散化数组 \(t\) 里,然后就做完了。 复杂度 \(O((n + m)\log(n + m))\)

实现
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
const int MAXM = 1e5 + 10;
ll n, m, k, d, blen;
#define mid ((l + r) >> 1)
#define ls ((x) << 1)
#define rs ((x) << 1 | 1)
struct node
{
    ll l, r, v;
} a[MAXM];
ll b[MAXM << 1], f[MAXM << 1];
bool cmp(node a, node b) { return a.r < b.r; }
ll lazy[MAXM << 3], tr[MAXM << 3];
void pu(int x) { tr[x] = max(tr[ls], tr[rs]); }
void build(int x, int l, int r)
{
    lazy[x] = 0;
    if (l == r)
    {
        tr[x] = (l ? -1e18 : 0);
        return;
    }
    build(ls, l, mid), build(rs, mid + 1, r);
    pu(x);
}
void addtag(int x, ll v)
{
    tr[x] += v;
    lazy[x] += v;
}
void pd(int x)
{
    if (!lazy[x])
        return;
    addtag(ls, lazy[x]);
    addtag(rs, lazy[x]);
    lazy[x] = 0;
}
ll query(int x, int l, int r, int pl, int pr)
{
    if (pl <= l && r <= pr)
        return tr[x];
    pd(x);
    ll ans = -1e18;
    if (pl <= mid)
        ans = max(ans, query(ls, l, mid, pl, pr));
    if (mid < pr)
        ans = max(ans, query(rs, mid + 1, r, pl, pr));
    return ans;
}
void assign(int x, int l, int r, int p, ll val)
{
    if (l == r)
    {
        tr[x] = val;
        return;
    }
    pd(x);
    if (p <= mid)
        assign(ls, l, mid, p, val);
    else
        assign(rs, mid + 1, r, p, val);
    pu(x);
}
void update(int x, int l, int r, int pl, int pr, ll val)
{
    if (pl <= l && r <= pr)
    {
        addtag(x, val);
        return;
    }
    pd(x);
    if (pl <= mid)
        update(ls, l, mid, pl, pr, val);
    if (mid < pr)
        update(rs, mid + 1, r, pl, pr, val);
    pu(x);
}
void solve()
{
    cin >> n >> m >> k >> d;
    blen = 0;
    for (int i = 1; i <= m; i++)
    {
        cin >> a[i].r >> a[i].l >> a[i].v;
        a[i].l = a[i].r - a[i].l;
        b[++blen] = a[i].l, b[++blen] = a[i].r;
    }
    sort(b + 1, b + blen + 1);
    sort(a + 1, a + m + 1, cmp);
    blen = unique(b + 1, b + blen + 1) - b - 1;
    for (int i = 1; i <= m; i++)
    {
        a[i].l = lower_bound(b + 1, b + blen + 1, a[i].l) - b;
        a[i].r = lower_bound(b + 1, b + blen + 1, a[i].r) - b;
    }
    build(1, 0, blen);
    ll ans = 0;
    for (int i = 1, lst = 0, ridx = 1; i <= blen; i++)
    {
        assign(1, 0, blen, i, ans + d * b[i]);
        for (; ridx <= m && a[ridx].r == i; ridx++)
            update(1, 0, blen, 0, a[ridx].l, a[ridx].v);
        for (; lst < i && b[lst] < b[i] - k; lst++);
        if (lst < i)
            f[i] = query(1, 0, blen, lst, i - 1) - d * b[i];
        else
            f[i] = -1e18;
        ans = max(ans, f[i]);
    }
    cout << ans << endl;
}
int main()
{
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    int c, T;
    cin >> c >> T;
    while (T--)
        solve();
    return 0;
}