一些动态规划问题的转移可以转化为动态的区间修改与查询,这时便可以利用线段树来加速了
下面是一些例子
CF833B The Bakery
将一个长度为 \(n\) 的序列分为 \(k\) 段,使得总价值最大。
一段区间的价值表示为区间内不同数字的个数。
数据范围: \(n \leq 35000, k \leq 50\)
我们记 \(w(i, j)\) 为 \([l, r]\) 内不同数字的个数,那么设 \(f_{i,k}\) 为将前 \(i\) 个数分成 \(j\) 段所能得到的最大总价值
那么容易得出转移方程
\[f_{i,k + 1} = \max_{0 \le j < i}\{f_{j,k} + w(j + 1, i)\}\]
我们发现 \(w(i,j)\) 很难在短时间内重复计算,那么我们换个思路,每个数会对哪些地方的 \(w(i,j)\) 值产生贡献?
记这个序列为 \(a_1,a_2,\ldots,a_n\),记 \(lst_i\) 为 \(1 \sim i-1\) 中 \(a_i\) 出现的最右侧位置的下标加一,那么 \(a_i\) 会对 \(j \in [lst_i, i], t \in [i, n]\) 的 \(w(j, t)\) 产生 \(1\) 的贡献
然后上面的求 \(\max\) 就成了 \(f_{0, \ldots ,i - 1}\) 里面取最大值
这是一个区间加与区间查最大值,显然可以用线段树优化
那么直接暴力枚举 \(k\),每次先清空,然后用线段树维护 \(f_{k - 1, j - 1} + w(j, i)\) 的值
考虑让 \(i\) 从 \(1\) 循环到 \(n\),先让 \([lst_i, i]\) 加一,然后在 \([0,i]\) 里面查询最大值记到 \(f_{k,i}\) 里面就行
时间复杂度 \(O(kn\log n)\),代码如下:
实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84 | #include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
#define mid ((l + r) >> 1)
#define ls ((x) << 1)
#define rs ((x) << 1 | 1)
const int MAXN = 35010, inf = 0x7fffffff;
int dp[60][MAXN], lst[MAXN], pos[MAXN], tr[MAXN << 2], lazy[MAXN << 2];
void addtag(int x, int v)
{
tr[x] += v;
lazy[x] += v;
}
void pu(int x) { tr[x] = max(tr[ls], tr[rs]); }
void pd(int x)
{
if (!lazy[x])
return;
addtag(ls, lazy[x]);
addtag(rs, lazy[x]);
lazy[x] = 0;
}
void build(int x, int l, int r, int now)
{
tr[x] = lazy[x] = 0;
if (l == r)
{
tr[x] = dp[now][l - 1];
return;
}
build(ls, l, mid, now);
build(rs, mid + 1, r, now);
pu(x);
}
void update(int x, int l, int r, int pl, int pr, int val)
{
if (pl <= l && r <= pr)
{
addtag(x, val);
return;
}
pd(x);
if (pl <= mid)
update(ls, l, mid, pl, pr, val);
if (mid < pr)
update(rs, mid + 1, r, pl, pr, val);
pu(x);
}
int query(int x, int l, int r, int pl, int pr)
{
if (pl <= l && r <= pr)
return tr[x];
pd(x);
int res = -inf;
if (pl <= mid)
res = max(res, query(ls, l, mid, pl, pr));
if (mid < pr)
res = max(res, query(rs, mid + 1, r, pl, pr));
return res;
}
int main()
{
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
int n, k, t;
cin >> n >> k;
for (int i = 1; i <= n; i++)
{
cin >> t;
lst[i] = pos[t] + 1;
pos[t] = i;
}
for (int i = 1; i <= k; i++)
{
build(1, 0, n, i - 1);
for (int j = 1; j <= n; j++)
{
update(1, 0, n, lst[j], j, 1);
dp[i][j] = query(1, 0, n, 0, j);
}
}
cout << dp[k][n];
return 0;
}
|
再来一题
P2605 [ZJOI2010] 基站选址
有 \(N\) 个村庄坐落在一条直线上,第 \(i(i>1)\) 个村庄距离第 \(1\) 个村庄的距离为 \(D_i\)。需要在这些村庄中建立不超过 \(K\) 个通讯基站,在第 \(i\) 个村庄建立基站的费用为 \(C_i\)。如果在距离第 \(i\) 个村庄不超过 \(S_i\) 的范围内建立了一个通讯基站,那么就村庄被基站覆盖了。如果第 \(i\) 个村庄没有被覆盖,则需要向他们补偿,费用为 \(W_i\)。现在的问题是,选择基站的位置,使得总费用最小。
\(100\%\) 的数据中,\(K\leq N\), \(K\leq 100\), \(N\leq 2\times 10^4\), \(D_i \leq 10^9\), \(C_i\leq 10^4\), \(S_i \leq10^9\), \(W_i \leq 10^4\)。
我们设 \(f_{i,j}\) 表示第 \(j\) 个基站建立在 \(i\) 位置,考虑 \([1, i]\) 位置产生的最小总费用。
容易写出转移方程
\[f_{i,k} = \min_{j = 1}^{k-1}\{f_{j,k - 1} + cost(j,i)\} + c_i\]
其中 \(cost(j,i)\) 为 \([j + 1,i - 1]\) 区间中因为信号覆盖不到产生的总赔偿费用,\(c_i\) 为在 \(i\) 处建基站的费用。
看到式子里有 \(\min\),并且显然 \(1\le k\le i - 1\),很容易想到用线段树来处理,但是又发现这个 \(cost(k,i)\) 不太好维护。下面重点说说这个怎么处理。
我们不直接计算 \(cost(j,i)\),而是考虑这样一个问题,对于一个村庄 \(x\),什么时候它会产生赔偿 \(w_x\)?
我们先预处理出最左边和最右边的信号能覆盖到 \(x\) 的村庄编号 \(L_x\) 和 \(R_x\),那么当 \(k<L_x\) 且 \(i>R_x\) 时,会产生赔偿 \(w_x\)。放在线段树上处理就是:当 \(i = R_x+1\) 时,我们就让线段树上 \([1, L_x-1]\) 的位置加上 \(w_x\),表示如果上一个基站建在 \([1,L_x-1]\) 的位置的话,那么 \(w_x\) 就会产生贡献了。 那么如何方便地找到刚好在 \(i - 1\) 位置的那些 \(R_x\) 呢?用一个vector来存就可以了。
这里运用一个小技巧,让 \(n \leftarrow n + 1\), \(k \leftarrow k + 1\),然后在新的 \(n\) 位置加一个虚拟村庄,最后答案就是 \(ans = f[n]\),因为是让第 \(k + 1\) 个虚拟基站建在这个虚拟村庄这里,用于收集答案。
实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107 | #include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define endl '\n'
#define mid ((l + r) >> 1)
#define ls ((x) << 1)
#define rs ((x) << 1 | 1)
inline void chkmn(ll &a, ll b) { a = a > b ? b : a; }
const int MAXN = 2e4 + 10;
const ll inf = 1e18;
vector<int> dr[MAXN];
ll tr[MAXN << 2], d[MAXN], c[MAXN], s[MAXN], L[MAXN], R[MAXN], w[MAXN], ans, lazy[MAXN << 2], f[MAXN];
int n, k;
void addtag(int x, ll v)
{
tr[x] += v;
lazy[x] += v;
}
void pu(int x) { tr[x] = min(tr[ls], tr[rs]); }
void pd(int x)
{
if (lazy[x])
{
addtag(ls, lazy[x]);
addtag(rs, lazy[x]);
lazy[x] = 0;
}
}
void build(int x, int l, int r)
{
lazy[x] = 0;
if (l >= r)
{
tr[x] = f[l];
return;
}
build(ls, l, mid);
build(rs, mid + 1, r);
pu(x);
}
void update(int x, int l, int r, int pl, int pr, ll v)
{
if (pl > pr)
return;
if (pl <= l && r <= pr)
return addtag(x, v);
pd(x);
if (pl <= mid)
update(ls, l, mid, pl, pr, v);
if (mid < pr)
update(rs, mid + 1, r, pl, pr, v);
pu(x);
}
ll query(int x, int l, int r, int pl, int pr)
{
if (pl > pr)
return inf;
if (pl <= l && r <= pr)
return tr[x];
ll res = inf;
if (pl <= mid)
chkmn(res, query(ls, l, mid, pl, pr));
if (mid < pr)
chkmn(res, query(rs, mid + 1, r, pl, pr));
return res;
}
int main()
{
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
cin >> n >> k;
for (int i = 2; i <= n; i++)
cin >> d[i];
for (int i = 1; i <= n; i++)
cin >> c[i];
for (int i = 1; i <= n; i++)
cin >> s[i];
for (int i = 1; i <= n; i++)
cin >> w[i];
d[++n] = inf;
ll cur = 0;
for (int i = 1; i <= n; i++)
{
L[i] = lower_bound(d + 1, d + 1 + n, d[i] - s[i]) - d;
R[i] = upper_bound(d + 1, d + 1 + n, d[i] + s[i]) - d - 1;
dr[R[i]].push_back(i);
}
for (int i = 1; i <= n; i++)
{
f[i] = cur + c[i];
for (int x : dr[i])
cur += w[x];
}
ans = f[n];
for (int i = 2; i <= k + 1; i++)
{
build(1, 1, n);
for (int j = 1; j <= n; j++)
{
f[j] = query(1, 1, n, 1, j - 1) + c[j];
for (int x : dr[j])
update(1, 1, n, 1, L[x] - 1, w[x]);
}
chkmn(ans, f[n]);
}
cout << ans;
return 0;
}
|
小练习
P9871 [NOIP2023] 天天爱打卡
小 T 同学非常热衷于跑步。为了让跑步更加有趣,他决定制作一款叫做《天天爱打卡》的软件,使得用户每天都可以进行跑步打卡。
开发完成后,小 T 同学计划进行试运行,他找了大 Y 同学来帮忙。试运行共 \(n\) 天,编号为从 \(1\) 到 \(n\)。
对大 Y 同学来说,如果某天他选择跑步打卡,那么他的能量值会减少 \(d\)。初始时,他的能量值是 \(0\),并且试运行期间他的能量值可以是负数。
而且大 Y 不会连续跑步打卡超过 \(k\) 天;即不能存在 \(1\le x\le n-k\),使得他在第 \(x\) 到第 \(x+k\) 天均进行了跑步打卡。
小 T 同学在软件中设计了 \(m\) 个挑战,第 \(i\)( \(1\le i \le m\))个挑战可以用三个正整数 \((x_i,y_i,v_i)\) 描述,表示如果在第 \(x_i\) 天时,用户已经连续跑步打卡至少 \(y_i\) 天(即第 \(x_i-y_i+1\) 到第 \(x_i\) 天均完成了跑步打卡),那么小 T 同学就会请用户吃饭,从而使用户的能量值提高 \(v_i\)。
现在大 Y 想知道,在软件试运行的 \(n\) 天结束后,他的能量值最高可以达到多少?
本题有 \(t\) 组测试数据。
记 \(l_i=x_i-y_i+1\), \(r_i=x_i\) ;
对于所有测试数据,保证: \(1\le t\le 10\), \(1\le k\le n\le 10^9\), \(1\le m\le 10^5\), \(1\le l_i\le r_i\le n\), \(1\le d,v_i\le 10^9\)。
分析
考虑设dp状态 \(f_i\) 为考虑前 \(i\) 天且强制第 \(i\) 天跑步的最大值,同样设 \(g_i = \max_{j = 1}^{i} f_j\)。
易得一个转移方程:
\[f_i=\max_{j = i - k}^{i - 1}\{g_{j - 1}-(i - j)\cdot d+\sum_{[l_p,r_p]\subseteq(j,i]}v_p\}\]
也就是强制第 \(i\) 天跑步,且第 \(j\) 天不跑。
这个dp方程很经典,考虑使用线段树优化。
先考虑如何做到 \(O(n\log n)\)。
我们考虑线段树下标是 \(j\) 的决策点,存的东西是决策对应的值(记为 \(val_j\)),考虑当 \(i\) 变化时不同决策点的代价如何变化。
- 首先每一个决策点 \(j\) 到 \(i\) 都比到 \(i - 1\) 要多跑一天的步,所以 \(val_j\) 要减去 \(d\)。
- 然后这时候满足 \(j<l_p\) 且 \(i\geq r_p\) 的挑战 \(p\) 变得可以选择(因为随着 \(i\) 的不断右移,只要满足上述条件这些挑战就始终能选择),也就是对于所有的 \(p\) 满足 \(j<l_p\) 且 \(i\geq r_p\), \(val_j\) 要加上 \(v_p\)。
- 然后就得到了 \(val_j\) 的变化方式,而 \(val_j\) 对比转移方程恰好在 \(j\) 作为决策点时取到,所以可以据此进行优化。
那么这就是一个 \(O(n\log n)\) 的做法。
然后考虑如何优化到 \(O((n + m)\log(n + m))\)。容易发现的是,我们在转移过程中,只有 \(l_p\) 和 \(r_p\) 是有用的,所以我们考虑离散化,把 \(l_p\) 和 \(r_p\) 丢到离散化数组 \(t\) 里,然后就做完了。
复杂度 \(O((n + m)\log(n + m))\)。
实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124 | #include <bits/stdc++.h>
using namespace std;
#define endl '\n'
typedef long long ll;
const int MAXM = 1e5 + 10;
ll n, m, k, d, blen;
#define mid ((l + r) >> 1)
#define ls ((x) << 1)
#define rs ((x) << 1 | 1)
struct node
{
ll l, r, v;
} a[MAXM];
ll b[MAXM << 1], f[MAXM << 1];
bool cmp(node a, node b) { return a.r < b.r; }
ll lazy[MAXM << 3], tr[MAXM << 3];
void pu(int x) { tr[x] = max(tr[ls], tr[rs]); }
void build(int x, int l, int r)
{
lazy[x] = 0;
if (l == r)
{
tr[x] = (l ? -1e18 : 0);
return;
}
build(ls, l, mid), build(rs, mid + 1, r);
pu(x);
}
void addtag(int x, ll v)
{
tr[x] += v;
lazy[x] += v;
}
void pd(int x)
{
if (!lazy[x])
return;
addtag(ls, lazy[x]);
addtag(rs, lazy[x]);
lazy[x] = 0;
}
ll query(int x, int l, int r, int pl, int pr)
{
if (pl <= l && r <= pr)
return tr[x];
pd(x);
ll ans = -1e18;
if (pl <= mid)
ans = max(ans, query(ls, l, mid, pl, pr));
if (mid < pr)
ans = max(ans, query(rs, mid + 1, r, pl, pr));
return ans;
}
void assign(int x, int l, int r, int p, ll val)
{
if (l == r)
{
tr[x] = val;
return;
}
pd(x);
if (p <= mid)
assign(ls, l, mid, p, val);
else
assign(rs, mid + 1, r, p, val);
pu(x);
}
void update(int x, int l, int r, int pl, int pr, ll val)
{
if (pl <= l && r <= pr)
{
addtag(x, val);
return;
}
pd(x);
if (pl <= mid)
update(ls, l, mid, pl, pr, val);
if (mid < pr)
update(rs, mid + 1, r, pl, pr, val);
pu(x);
}
void solve()
{
cin >> n >> m >> k >> d;
blen = 0;
for (int i = 1; i <= m; i++)
{
cin >> a[i].r >> a[i].l >> a[i].v;
a[i].l = a[i].r - a[i].l;
b[++blen] = a[i].l, b[++blen] = a[i].r;
}
sort(b + 1, b + blen + 1);
sort(a + 1, a + m + 1, cmp);
blen = unique(b + 1, b + blen + 1) - b - 1;
for (int i = 1; i <= m; i++)
{
a[i].l = lower_bound(b + 1, b + blen + 1, a[i].l) - b;
a[i].r = lower_bound(b + 1, b + blen + 1, a[i].r) - b;
}
build(1, 0, blen);
ll ans = 0;
for (int i = 1, lst = 0, ridx = 1; i <= blen; i++)
{
assign(1, 0, blen, i, ans + d * b[i]);
for (; ridx <= m && a[ridx].r == i; ridx++)
update(1, 0, blen, 0, a[ridx].l, a[ridx].v);
for (; lst < i && b[lst] < b[i] - k; lst++);
if (lst < i)
f[i] = query(1, 0, blen, lst, i - 1) - d * b[i];
else
f[i] = -1e18;
ans = max(ans, f[i]);
}
cout << ans << endl;
}
int main()
{
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
int c, T;
cin >> c >> T;
while (T--)
solve();
return 0;
}
|