动态规划

本文内容参考《算法导论》,OI Wiki

基础知识

动态规划方法通常用来求解最优化问题。这类问题可以有很多可行解,每个解都有一个值,我们希望寻找具有最优值(最小值或最大值)的解。

适合应用动态规划方法求解的最优化问题应该具备两个要素:最优子结构重叠子问题

  • 最优子结构:问题的最优解由相关子问题的最优解组合而成,而这些子问题可以独立求解。
  • 重叠子问题:如果问题的递归算法会反复求解相同的子问题,我们就称最优化问题具有重叠子问题性质。

子问题图

子问题图是一个有向图,每个顶点唯一的对应一个子问题。若求子问题 \(x\) 的最优解时需要直接用到子问题 \(y\) 的最优解,那么在子问题图中就会有一条从子问题 \(x\) 的顶点到子问题 \(y\) 的顶点的有向边。

自顶向下动态规划处理子问题图中顶点的顺序为拓扑序,自底向上动态规划处理子问题图中顶点的顺序为逆拓扑序。

通常情况下,动态规划算法的运行时间与子问题图中顶点和边的数量呈线性关系。

选择自顶向下,还是自底向上

通常情况下,如果每个子问题都必须至少求解一次,自底向上动态规划算法会更快,因为没有递归调用的开销,而且对于某些问题,可以利用表的访问模式降低时空开销。如果子问题空间中的某些子问题完全不必求解,自顶向下动态规划算法会更快,因为它只会求解那些必要的子问题。

背包 DP

题目:有 \(n\) 种物品和一个容量为 \(W\) 的背包,每种物品有数量 \(k_{i}\)、重量 \(w_{i}\) 和价值 \(v_{i}\) 三种属性,要求选若干物品放入背包使背包中物品的总价值最大且背包中物品的总重量不超过背包的容量。

0-1 背包

每种物品只能取一次,即 \(k_{i}=1\) 对任意 \(i\) 都成立。

转移方程:

$$ dp[i][j]=\max{(dp[i-1][j],dp[i-1][j-w[i]]+v[i])} $$

空间优化(倒序枚举):

$$ dp[j]=\max{(dp[j],dp[j-w[i]]+v[i])} $$

完全背包

每种物品可以取无限次,即 \(k_{i}=\infty\) 对任意 \(i\) 都成立。

转移方程:

$$ dp[i][j]=\max_{k=0}^{\infty}{dp[i-1][j-k\cdot w[i]]+k\cdot v[i]} $$

方程优化:

$$ dp[i][j]=\max{(dp[i-1][j],dp[i][j-w[i]]+v[i])} $$

空间优化(正序枚举):

$$ dp[j]=\max(dp[j],dp[j-w[i]]+v[i]) $$

多重背包

每种物品可以取 \(k_{i}\) 次,即 \(k_{i}\in\mathbb{N}\) 对任意 \(i\) 都成立。

转移方程:

$$ dp[i][j]=\max_{k=0}^{k[i]}{dp[i-1][j-k\cdot w[i]]+k\cdot v[i]} $$

二进制分组优化:将每种物品的 \(k_{i}\) 拆分为多个组,每组的数量为 \(2^{0},2^{1},\dots,2^{\lfloor{\log{k_{i}+1}}\rfloor -1}\),如果 \(k_{i}+1\) 不是二的幂,就将多余的数量作为一组,最后将 \(k_{i}\) 拆出来的每组都看作数量为 \(1\) 的新物品,从而转化为 0-1 背包。可以证明,如果选择 \(x\) 次第 \(i\) 种物品,其中 \(x\in[0,k_{i}]\),则该选择方式总是可以由分组后的新物品的某个组合表示。

例题

Codeforces Round 905 (Div. 2)

Chemistry

只要奇数字母的个数不大于 \(k+1\) 即可,因为回文串最多有一个奇数字母。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public static void solve() {
int n = io.nextInt(), k = io.nextInt();
String s = io.next();

int[] cnt = new int[26];
for (int i = 0; i < n; i++) {
cnt[s.charAt(i) - 'a']++;
}

int sum = 0;
for (int x : cnt) {
if (x % 2 == 1) {
sum++;
}
}
io.println(sum - 1 > k ? "NO" : "YES");
}

Raspberries

当 \(k=2,3,5\) 时,因为 \(k\) 是质数,如果所有数的乘积能够被 \(k\) 整除,必定存在一个数能够被 \(k\) 整除,所以单独计算每个数即可。当 \(k=4\) 时,需要计算存在一个数能被 \(4\) 整除的最少操作数,还需要计算存在两个能被 \(2\) 整除的数的最少操作数,答案为两者的最小值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
public static void solve() {
int n = io.nextInt(), k = io.nextInt();

int cnt = 0;
int[] a = new int[n];
for (int i = 0; i < n; i++) {
a[i] = io.nextInt();
if (a[i] % 2 == 0) {
cnt++;
}
}

int ans = k;
for (int i = 0; i < n; i++) {
ans = Math.min(ans, (k - a[i] % k) % k);
}
if (k == 4) {
ans = Math.min(ans, Math.max(0, 2 - cnt));
}
io.println(ans);
}

You Are So Beautiful

如果某个子数组作为子序列只出现过一次,因为子数组本身就是子序列,所以没有其他方式能够构成该子数组,即子数组的左端点左边没有和它相同的数,右端点的右边也没有和它相同的数。我们可以使用集合 + 前缀和的方式预先计算每个位置及其左边满足条件的左端点个数,然后倒序处理数组,对每个满足条件的右端点,都将其对应的左端点的个数添加到答案。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public static void solve() {
int n = io.nextInt();
int[] a = new int[n];
for (int i = 0; i < n; i++) {
a[i] = io.nextInt();
}

Set<Integer> set = new HashSet<>();
int[] prefix = new int[n + 1];
for (int i = 0; i < n; i++) {
prefix[i + 1] = prefix[i] + (set.add(a[i]) ? 1 : 0);
}
set.clear();

long ans = 0L;
for (int i = n - 1; i >= 0; i--) {
if (set.add(a[i])) {
ans += prefix[i + 1];
}
}
io.println(ans);
}

Dances (Easy version)

题目真难读,简单版只会在数组 \(a\) 中添加一个 \(1\),然后计算最少操作数,可以使用排序 + 双指针进行处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public static void solve() {
int n = io.nextInt(), m = io.nextInt();
int[] a = new int[n];
a[0] = 1;
for (int i = 1; i < n; i++) {
a[i] = io.nextInt();
}
int[] b = new int[n];
for (int i = 0; i < n; i++) {
b[i] = io.nextInt();
}

Arrays.sort(a);
Arrays.sort(b);

int k = 0;
for (int i = 0, j = 0; i < n - k; i++, j++) {
while (j < n && a[i] >= b[j]) {
k++;
j++;
}
}
io.println(k);
}

Dances (Hard Version)

困难版,计算在数组中分别添加 \([1,m]\) 需要的最少操作数。通过观察可以发现(真发现不了),改变 \(a[0]\) 最多只会使操作次数加 \(1\),所以我们可以二分该边界值,然后计算答案即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public static void solve() {
int n = io.nextInt(), m = io.nextInt();
int[] a = new int[n];
for (int i = 1; i < n; i++) {
a[i] = io.nextInt();
}
int[] b = new int[n];
for (int i = 0; i < n; i++) {
b[i] = io.nextInt();
}
Arrays.sort(b);

int k = calc(a, b, 1);

int lo = 1, hi = m;
while (lo <= hi) {
int mid = lo + (hi - lo) / 2;
if (calc(a, b, mid) == k) lo = mid + 1;
else hi = mid - 1;
}
io.println((long) m * k + (m - lo + 1));
}

private static int calc(int[] a, int[] b, int x) {
a[0] = x;
a = a.clone();
Arrays.sort(a);

int n = a.length, k = 0;
for (int i = 0, j = 0; i < n - k; i++, j++) {
while (j < n && a[i] >= b[j]) {
k++;
j++;
}
}
return k;
}

竟然还有更简单的方法,首先计算 \(a\) 中 \(n-1\) 个数对应 \(b\) 中 \(n\) 个数的最少删除次数,并同时维护 \(b\) 中不满足 \(a[i]<b[j]\) 的最后一个值,该值就是操作次数的分界点,直接计算答案即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public static void solve() {
int n = io.nextInt(), m = io.nextInt();
int[] a = new int[n];
a[0] = Integer.MAX_VALUE;
for (int i = 1; i < n; i++) {
a[i] = io.nextInt();
}
int[] b = new int[n];
for (int i = 0; i < n; i++) {
b[i] = io.nextInt();
}

Arrays.sort(a);
Arrays.sort(b);

int k = 0, val = m + 1;
for (int i = 0, j = 0; i < n - k; i++, j++) {
while (j < n && a[i] >= b[j]) {
val = b[j];
k++;
j++;
}
}
io.println((long) m * (k - 1) + Math.max(0, m - val + 1));
}

Codeforces Round 904 (Div. 2)

Simple Design

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public static void solve() {
int x = io.nextInt(), k = io.nextInt();
while (sum(x) % k != 0) {
x++;
}
io.println(x);
}

private static int sum(int x) {
int res = 0;
while (x != 0) {
res += x % 10;
x /= 10;
}
return res;
}

Haunted House

好难啊,做得很慢。对于每个 \(i\),如果它是满足条件的,那么 \([n-i,n-1]\) 需要全为 \(0\),它的最少操作次数为 \([n-i,n-1]\) 中所有值为 \(1\) 的下标和,减去 \([0,n-i-1]\) 中最近的值为 \(0\) 的对应个数的下标和。我们可以使用双指针 \(O(n)\) 的计算所有 \(i\),具体见代码。指针 \(j\) 枚举每个下标,同时求出后缀的下标和,指针 \(i\) 指向指针 \(j\) 需要的最远的 \(0\) 的下标位置,同时求出后缀值为 \(0\) 的下标和,它们的差值就是 \(j\) 的最少操作次数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public static void solve() {
int n = io.nextInt();
String s = io.next();

long sum = 0L;
int i, j, cnt = 0;
for (i = n - 1, j = n - 1; i >= 0; j--) {
cnt++;
sum += j;
for (; i >= 0 && cnt > 0; i--) {
if (s.charAt(i) == '0') {
cnt--;
sum -= i;
}
}
io.print(cnt > 0 ? "-1 " : sum + " ");
}
io.println("-1 ".repeat(j + 1));
}

发现一个超级简单的写法,基本思路就是从低到高放置 \(0\),操作次数即为 \(0\) 的移动次数,废话不多说,代码很好懂。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public static void solve() {
int n = io.nextInt();
char[] s = io.next().toCharArray();

long ans = 0L;
int l = n - 1, r = n - 1;
for (; l >= 0; l--) {
if (s[l] == '0') {
ans += r - l;
r--;
io.print(ans + " ");
}
}
io.println("-1 ".repeat(r + 1));
}

Medium Design

最小值一定在位置 \(1\) 或位置 \(m\),我们可以考虑处理区间不包含 \(1\) 和不包含 \(m\) 两种情况下,能够得到的最大值,根据简单的推导可以知道问题是等价的。如何计算最大值,根据题解所说似乎是扫描线算法,使用 \((l,1)\) 表示进入某个区间,\((r,-1)\) 表示离开某个区间,注意初始时我们将每个左端点减 \(1\),表示从区间 \(0\) 开始算,\((l,r)\) 是左闭右开区间,所以 \(r\) 表示离开某个区间,然后每当处理完某个端点就更新答案。(其实也可以使用差分哈希表进行区间求和)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
public static void solve() {
int n = io.nextInt(), m = io.nextInt();
int[] l = new int[n];
int[] r = new int[n];
for (int i = 0; i < n; i++) {
l[i] = io.nextInt() - 1;
r[i] = io.nextInt();
}
// 第一次扫描,不包含第一个位置
List<int[]> list = new ArrayList<>();
for (int i = 0; i < n; i++) {
if (l[i] != 0) {
list.add(new int[]{l[i], 1});
list.add(new int[]{r[i], -1});
}
}
list.sort((a, b) -> a[0] - b[0]);
int ans = sweep(list);
// 第二次扫描,不包含最后一个位置
list.clear();
for (int i = 0; i < n; i++) {
if (r[i] != m) {
list.add(new int[]{l[i], 1});
list.add(new int[]{r[i], -1});
}
}
list.sort((a, b) -> a[0] - b[0]);
ans = Math.max(ans, sweep(list));
io.println(ans);
}

private static int sweep(List<int[]> list) {
// cnt 表示在多少个区间内
// lst 表示上次处理的端点
int res = 0, cnt = 0, lst = 0;
for (int[] t : list) {
if (t[0] > lst) {
res = Math.max(res, cnt);
}
cnt += t[1];
lst = t[0];
}
return res;
}

Counting Rhyme

一对数 \(x,y\) 不能同时被数组中的数整除,即 \(\gcd(x,y)\) 不能被数组中的数整除。我们首先可以计算出数组中有多少对数它们的 \(\gcd=1,2,3\dots,n\),然后排除掉能够被数组中的数整除的 \(\gcd\),剩下的 \(\gcd\) 对应的对数之和就是答案。第一步可以使用动态规划求解,转移方程如下:

$$ sum = cnt[i]+cnt[2\times i]+\cdots+cnt[k\times i] \\ dp[i]= \frac{sum\times (sum-1)}{2}-(dp[2\times i]+dp[3\times i]+\cdots+dp[k\times i]) $$

第二步切记不能枚举数组中的数来排除,这样在所有值都为 \(1\) 的样例下时间复杂度会达到 \(O(n^{2})\),除非将数组去重,或者像下面代码一样枚举。最后计算答案即可。本题的另一种解法是 GCD 卷积,暂时不学。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
public static void solve() {
int n = io.nextInt();
int[] a = new int[n];
int[] cnt = new int[n + 1];
for (int i = 0; i < n; i++) {
a[i] = io.nextInt();
cnt[a[i]]++;
}

long[] dp = new long[n + 1];
for (int i = n; i > 0; i--) {
long tot = 0L;
for (int j = i; j <= n; j += i) {
tot += cnt[j];
dp[i] -= dp[j];
}
dp[i] += tot * (tot - 1) / 2;
}

for (int i = 1; i <= n; i++) {
if (cnt[i] != 0) {
for (int j = i; j <= n; j += i) {
dp[j] = 0;
}
}
}

long ans = 0L;
for (int i = 1; i <= n; i++) {
ans += dp[i];
}
io.println(ans);
}