数位 DP
引入
在 OI 中,一些题目会给定一个区间 \([L, R]\) ,要求我们统计其中符合一些性质的数的个数。
由于给定的区间长度可能会很大(例如 \(10 ^ 9\) 、 \(10 ^ {18}\) 甚至是 \(10 ^ {10 ^ 5}\) ),因此我们需要借助计数类 DP 来完成,这就是数位 DP。
框架
首先,我们将区间 \([L, R]\) 拆成 \([1, R] - [1, L - 1]\) ,或者 \([1, R] - [1, L] + {L}\) 。
然后就转成了对 \([1, x]\) 计数。
一般情况下,对于数位 DP 而言,记忆化搜索是一个比较简单的实现方式。
我们先将 \(x\) 在给定进制下分解为 \(len\) 位置的数,然后我们考虑从高到低一位一位填上去。
int len , num [ MAXN ], f [ MAXN ][ MAXN ][...];
int dp ( int d , bool up , ...) // up 判断是否前面都按照上限填
{
if ( ! up && ~ f [ d ][...]) return f [ d ][...];
int res = 0 , lim = up ? num [ d ] : 9 ; // 该位能填的最高数码
rep ( i , 0 , lim ) if ( conditions (...)) add ( res , dp ( d + 1 , up && i == lim , ...));
if ( up ) return res ;
return f [ d ][...] = res ;
}
还是比较好懂的。
例题
P4127 同类分布
给出两个数 \(a, b\) ,求出 \([a, b]\) 中各位数字之和能整除原数的数的个数,其中 \(1 \leq a \leq b \leq 10^{18}\) 。
我们考虑枚举这个各位数字和 \(s\) ,将原数对 \(s\) 取模的结果,以及当前各位和计入状态即可。
实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36 int num [ 20 ], len , f [ 20 ][ 170 ][ 170 ], tar ;
int dp ( int d , bool up , int sum , int now )
{
if ( sum > tar || sum + ( len - d + 1 ) * 9 < tar ) return 0 ; // 小剪枝,这也是记忆化搜索的一个优点。
if ( d == len + 1 ) return tar == sum && ! now ;
if ( ! up && ~ f [ d ][ sum ][ now ]) return f [ d ][ sum ][ now ];
int res = 0 , lim = up ? num [ d ] : 9 ;
rep ( i , 0 , lim ) res += dp ( d + 1 , up && i == lim , sum + i , ( now * 10 + i ) % tar );
if ( up ) return res ;
return f [ d ][ sum ][ now ] = res ;
}
bool chk ( string s )
{
int num = 0 , sum = 0 ;
for ( char c : s ) c -= '0' , num = num * 10 + c , sum += c ;
return num % sum == 0 ;
}
int solve ( const string & s )
{
Mst ( f , -1 ), len = 0 ;
for ( char c : s ) num [ ++ len ] = c - '0' ;
return dp ( 1 , 1 , 0 , 0 );
}
int calc ( const string & sl , const string & sr , int x )
{
tar = x ;
return solve ( sr ) - solve ( sl );
}
void Solve ()
{
string sl , sr ;
cin >> sl >> sr ;
int ans = chk ( sl );
rep ( i , 1 , 162 ) ans += calc ( sl , sr , i );
cout << ans ;
}
P6128 [USACO06NOV] Round Numbers S
如果一个正整数的二进制表示中,\(0\) 的数目不小于 \(1\) 的数目,那么它就被称为「圆数」。
例如,\(9\) 的二进制表示为 \(1001\) ,其中有 \(2\) 个 \(0\) 与 \(2\) 个 \(1\) 。因此,\(9\) 是一个「圆数」。
请你计算,区间 \([l,r]\) 中有多少个「圆数」,其中 \(1\le l,r\le 2\times 10^9\) 。
变成了二进制,不过本质没有区别。
这里注意一下前导零的事情,记录一下即可。
实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 int len , f [ 36 ][ 36 ][ 36 ], num [ 36 ];
int dp ( int d , bool up , bool lead , int zero , int one )
{
if ( zero + len - d + 1 < one ) return 0 ;
if ( d == len + 1 ) return zero >= one ;
if ( ! up && ! lead && ~ f [ d ][ zero ][ one ]) return f [ d ][ zero ][ one ];
int lim = up ? num [ d ] : 1 , res = 0 ;
res += dp ( d + 1 , up && ! lim , lead , lead ? 0 : zero + 1 , one );
if ( lim ) res += dp ( d + 1 , up && lim , 0 , zero , one + 1 );
if ( up || lead ) return res ;
return f [ d ][ zero ][ one ] = res ;
}
int solve ( int x )
{
len = 0 ;
while ( x ) num [ ++ len ] = x & 1 , x >>= 1 ;
reverse ( num + 1 , num + 1 + len );
Mst ( f , -1 );
return dp ( 1 , 1 , 1 , 0 , 0 );
}
void Solve ()
{
int l , r ;
cin >> l >> r ;
cout << solve ( r ) - solve ( l - 1 );
}