图论

本文内容参考《算法》,《算法导论》,OI Wiki

拓扑排序

例题

实现

  • 时间复杂度为 \(O(n+m)\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
private static int[] topologicalSort(int n, int[][] edges) {
int[] in = new int[n];
List<Integer>[] g = new List[n];
Arrays.setAll(g, k -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1];
g[u].add(v);
in[v]++;
}

Queue<Integer> q = new ArrayDeque<>();
for (int i = 0; i < n; i++) {
if (in[i] == 0) {
q.offer(i);
}
}

int idx = 0;
int[] res = new int[n];
while (!q.isEmpty()) {
int x = q.poll();
res[idx++] = x;
for (int y : g[x]) {
if (--in[y] == 0) {
q.offer(y);
}
}
}

// 拓扑排序不存在
assert idx == n;

return res;
}

最小生成树

例题

Prim

实现一:朴素版本

  • 时间复杂度为 \(O(n^{2})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
private static int prim(int n, int[][] edges) {
int[][] g = new int[n][n];
for (int i = 0; i < n; i++) {
Arrays.fill(g[i], Integer.MAX_VALUE);
g[i][i] = 0;
}
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
if (g[u][v] > w) {
g[u][v] = g[v][u] = w;
}
}

int[] d = new int[n];
Arrays.fill(d, Integer.MAX_VALUE);
boolean[] vis = new boolean[n];

int res = 0;
d[0] = 0;
for (int i = 0; i < n; i++) {
int t = -1;
for (int j = 0; j < n; j++) {
if (!vis[j] && (t == -1 || d[t] > d[j])) {
t = j;
}
}

// 不是连通图,最小生成树不存在
assert d[t] != Integer.MAX_VALUE;

vis[t] = true;
res += d[t];

for (int j = 0; j < n; j++) {
d[j] = Math.min(d[j], g[t][j]);
}
}

return res;
}

实现二:优先队列优化

  • 时间复杂度为 \(O(m\log{m})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
private static int prim(int n, int[][] edges) {
List<int[]>[] g = new List[n];
Arrays.setAll(g, k -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
g[u].add(new int[]{v, w});
g[v].add(new int[]{u, w});
}

int[] d = new int[n];
Arrays.fill(d, Integer.MAX_VALUE);
boolean[] vis = new boolean[n];
Queue<int[]> q = new PriorityQueue<>((a, b) -> a[1] - b[1]);

int res = 0, cnt = 0;
d[0] = 0;
q.offer(new int[]{0, 0});
while (!q.isEmpty()) {
int u = q.poll()[0];
if (vis[u]) continue;
vis[u] = true;
res += d[u];
if (++cnt == n) break;
for (int[] t : g[u]) {
int v = t[0], w = t[1];
if (!vis[v] && d[v] > w) {
d[v] = w;
q.offer(new int[]{v, d[v]});
}
}
}

// 不是连通图,最小生成树不存在
assert cnt == n;

return res;
}

Kruskal

  • 时间复杂度为 \(O(m\log{m})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
private static int kruskal(int n, int[][] edges) {
Arrays.sort(edges, (a, b) -> a[2] - b[2]);

int cnt = 1, res = 0;
UnionFind uf = new UnionFind(n);
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
if (uf.connected(u, v)) continue;
uf.union(u, v);
res += w;
if (++cnt == n) break;
}

// 不是连通图,最小生成树不存在
assert cnt == n;

return res;
}

最短路

例题

Dijkstra

  • 使用场景:解决边权非负的单源最短路问题。

实现一:朴素版本

  • 时间复杂度为 \(O(n^{2})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
private static final int INF = (int) 1e9;

private static int dijkstra(int n, int[][] edges) {
int[][] g = new int[n][n];
for (int i = 0; i < n; i++) {
Arrays.fill(g[i], INF);
g[i][i] = 0;
}
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
g[u][v] = Math.min(g[u][v], w);
}

int[] d = new int[n];
Arrays.fill(d, INF);
boolean[] vis = new boolean[n];

d[0] = 0;
while (true) {
int t = -1;
for (int i = 0; i < n; i++) {
if (!vis[i] && (t == -1 || d[t] > d[i])) {
t = i;
}
}

if (t == n - 1 || d[t] == INF) {
break;
}
vis[t] = true;

for (int i = 0; i < n; i++) {
d[i] = Math.min(d[i], d[t] + g[t][i]);
}
}

return d[n - 1] == INF ? -1 : d[n - 1];
}

实现二:优先队列优化

  • 时间复杂度为 \(O(m\log{m})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
private static final int INF = (int) 1e9;

private static int dijkstra(int n, int[][] edges) {
List<int[]>[] g = new List[n];
Arrays.setAll(g, k -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
g[u].add(new int[]{v, w});
}

int[] d = new int[n];
Arrays.fill(d, INF);
boolean[] vis = new boolean[n];
Queue<int[]> q = new PriorityQueue<>((a, b) -> a[1] - b[1]);

d[0] = 0;
q.offer(new int[]{0, 0});
while (!q.isEmpty()) {
int u = q.poll()[0];
if (u == n - 1) break;
if (vis[u]) continue;
vis[u] = true;
for (int[] t : g[u]) {
int v = t[0], w = t[1];
if (d[v] > d[u] + w) {
d[v] = d[u] + w;
q.offer(new int[]{v, d[v]});
}
}
}

return d[n - 1] == INF ? -1 : d[n - 1];
}

Bellman-Ford

  • 时间复杂度为 \(O(nm)\)。
  • 使用场景:解决任意边权的单源最短路问题;判断是否存在负环;解决有边数限制的单源最短路问题。

实现一:朴素版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
private static final int INF = (int) 1e9;

private static int bellmanFord(int n, int[][] edges) {
int[] d = new int[n];
Arrays.fill(d, INF);

d[0] = 0;
for (int i = 0; i < n; i++) {
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
d[v] = Math.min(d[v], d[u] + w);
}
}

// d[n - 1] == INF 时,最短路不存在
return d[n - 1];
}

实现二:队列优化(不能存在负环)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
private static final int INF = (int) 1e9;

private static int spfa(int n, int[][] edges) {
List<int[]>[] g = new List[n];
Arrays.setAll(g, k -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
g[u].add(new int[]{v, w});
}

int[] d = new int[n];
Arrays.fill(d, INF);
Queue<Integer> q = new ArrayDeque<>();
boolean[] on = new boolean[n];

d[0] = 0;
q.offer(0);
on[0] = true;
while (!q.isEmpty()) {
int u = q.poll();
on[u] = false;
for (int[] t : g[u]) {
int v = t[0], w = t[1];
if (d[v] > d[u] + w) {
d[v] = d[u] + w;
if (!on[v]) {
q.offer(v);
on[v] = true;
}
}
}
}

// d[n - 1] == INF 时,最短路不存在
return d[n - 1];
}

Floyd-Warshall

  • 时间复杂度为 \(O(n^{3})\)。
  • 使用场景:解决任意边权的多源最短路问题。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
private static final int INF = (int) 1e9;

private static int[][] floyd(int n, int[][] edges) {
int[][] dp = new int[n][n];
for (int i = 0; i < n; i++) {
Arrays.fill(dp[i], INF);
dp[i][i] = 0;
}
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
dp[u][v] = Math.min(dp[u][v], w);
}

for (int k = 0; k < n; k++) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (dp[i][k] != INF && dp[k][j] != INF) {
dp[i][j] = Math.min(dp[i][j], dp[i][k] + dp[k][j]);
}
}
}
}
return dp;
}

最近公共祖先

例题

倍增

  • 预处理时间复杂度为 \(O(n\log{n})\),查询时间复杂度为 \(O(\log{n})\)。
  • 原理:\(f[i][j]\) 表示节点 \(j\) 的第 \(2^{i}\) 个祖先,当利用倍增得到 \(f\) 时,对于任意两个节点 \(x,y\),先将较深的节点向上跳到相同深度,然后两个节点贪心的向上跳到 \(\operatorname{lca}\) 下方距离它最近的节点,最后得到的节点就是 \(\operatorname{lca}\) 的直接子节点。(在进行倍增时,根节点的父节点可以是任何值,因为该值不会影响算法的正确性)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
private static void dfs(int x, int fa, List<Integer>[] g, int[][] f, int[] d) {
f[0][x] = fa;
for (int i = 1; 1 << i <= d[x]; i++) {
f[i][x] = f[i - 1][f[i - 1][x]];
}
for (int y : g[x]) {
if (y != fa) {
d[y] = d[x] + 1;
dfs(y, x, g, f, d);
}
}
}

private static int lca(int x, int y, int[][] f, int[] d) {
if (d[x] > d[y]) {
int t = x; x = y; y = t;
}

int diff = d[y] - d[x];
for (int i = 0; i < 31; i++) {
if ((diff >> i & 1) == 1) {
y = f[i][y];
}
}

if (x != y) {
for (int i = 30; i >= 0; i--) {
if (f[i][x] != f[i][y]) {
x = f[i][x];
y = f[i][y];
}
}
x = f[0][x];
}
return x;
}

Tarjan

  • 离线查询算法,时间复杂度为 \(O((n+m)\log{n})\)。更精确的复杂度分析可以使用反阿克曼函数。
  • 原理:每当处理完一个子树,就将该子树的根节点和其父节点合并,特别注意合并的方向是 \(f[y]=x\)。然后我们会遍历包含当前节点 \(x\) 的查询,如果另一个节点 \(y\) 访问过,则 \(\operatorname{lca}(x,y)=\operatorname{find}(y)\)。至于为什么是这样,可以通过分类讨论得到。注意 \(q\) 需要像无向图一样,为单个查询存储双向边。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
private static void tarjan(int x, List<Integer>[] g, boolean[] vis, UnionFind uf, List<int[]>[] q, int[] ans) {
vis[x] = true;
for (int y : g[x]) {
if (!vis[y]) {
tarjan(y, g, vis, uf, q, ans);
uf.union(x, y); // 注意 f[y] = x
}
}

for (int[] t : q[x]) {
int y = t[0], i = t[1];
if (vis[y]) {
ans[i] = uf.find(y);
}
}
}

树链剖分

  • 预处理时间复杂度为 \(O(n)\),查询时间复杂度为 \(O(\log{n})\)。
  • 原理:将树划分为若干重链,树中的每条路径不会包含超过 \(\log{n}\) 条不同的重链,所以查询的时间复杂度为 \(O(\log{n})\)。第一次 DFS 得到每个节点的父节点,深度,以及根据子树大小得到每个节点的重子节点。第二次 DFS 通过优先遍历重子节点,再遍历轻子节点,从而得到每个节点所在重链的头节点。然后就可以进行查询,通过比较 \(x,y\) 所在重链的头节点,来向上跳跃,最终得到 \(\operatorname{lca}\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
private static void dfs1(int x, int fa, List<Integer>[] g, int[] f, int[] d, int[] s, int[] h) {
f[x] = fa;
s[x] = 1; h[x] = -1;
for (int y : g[x]) {
if (y != fa) {
d[y] = d[x] + 1;
dfs1(y, x, g, f, d, s, h);
s[x] += s[y];
if (h[x] == -1 || s[h[x]] < s[y]) {
h[x] = y;
}
}
}
}

private static void dfs2(int x, int head, List<Integer>[] g, int[] f, int[] h, int[] t) {
t[x] = head;
if (h[x] == -1) {
return;
}
dfs2(h[x], head, g, f, h, t);
for (int y : g[x]) {
if (y != f[x] && y != h[x]) {
dfs2(y, y, g, f, h, t);
}
}
}

private static int lca(int x, int y, int[] f, int[] d, int[] t) {
while (t[x] != t[y]) {
if (d[t[x]] > d[t[y]]) {
x = f[t[x]];
} else {
y = f[t[y]];
}
}
return d[x] < d[y] ? x : y;
}

强连通分量

例题

Tarjan

  • 时间复杂度为 \(O(n+m)\)。
  • 原理:\(dfn[x]\) 表示节点 \(x\) 的 DFS 编号;\(low[x]\) 表示节点 \(x\) 能够到达的节点的最小的 DFS 编号。我们将图看作一棵树,并定义四种边,那么强连通分量的根节点就是该分量中第一个被遍历到的节点,满足 \(dfn[x]=low[x]\),所以,过程很复杂,难以描述,直接看 wiki 吧。(注意使用的时候,将 \(dfn\) 初始化为 \(-1\),并且对所有节点调用该算法前,需要判断 \(dfn=-1\) 是否成立)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
private static int dfnCnt, sccCnt;

private static void tarjan(int x, List<Integer>[] g, int[] dfn, int[] low, Deque<Integer> stk, boolean[] on, int[] scc, int[] size) {
dfn[x] = low[x] = dfnCnt++;
stk.push(x);
on[x] = true;

for (int y : g[x]) {
if (dfn[y] == -1) {
tarjan(y, g, dfn, low, stk, on, scc, size);
low[x] = Math.min(low[x], low[y]);
} else if (on[y]) {
low[x] = Math.min(low[x], dfn[y]);
}
}

if (dfn[x] == low[x]) {
for (int y = -1; y != x; ) {
y = stk.pop();
on[y] = false;
scc[y] = sccCnt;
size[sccCnt]++;
}
sccCnt++;
}
}

Codeforces Round 907 (Div. 2)

Sorting with Twos

因为每次只能操作区间 \([1,2^{m}]\),所以 \([2^{m}+1,2^{m+1}]\) 内的所有数是同时进行操作的,它们需要满足非递减的性质,最后不要忘记结尾不能操作的数也需要满足条件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public static void solve() {
int n = io.nextInt();
int[] a = new int[n + 1];
for (int i = 1; i <= n; i++) {
a[i] = io.nextInt();
}

for (int i = 1; 1 << i <= n; i++) {
int j = Math.min(1 << (i + 1), n);
for (int k = (1 << i) + 1; k < j; k++) {
if (a[k] > a[k + 1]) {
io.println("NO");
return;
}
}
}
io.println("YES");
}

Deja Vu

如果一个数能够被 \(2^{i}\) 整除,那么操作之后,它只能被所有小于等于 \(2^{i-1}\) 的二的幂整除。所以预处理所有修改,只保留满足递减顺序的修改,然后模拟即可。或者也可以直接修改,不用预处理,在修改之前判断一下是否比上次小就行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public static void solve() {
int n = io.nextInt(), q = io.nextInt();
int[] a = new int[n];
for (int i = 0; i < n; i++) {
a[i] = io.nextInt();
}

int mask = 0;
for (int i = 0; i < q; i++) {
int x = io.nextInt();
if (mask == 0 || (1 << x) <= (mask & -mask)) {
mask |= 1 << x;
}
}

for (int i = 0; i < n; i++) {
for (int j = 30; j >= 0; j--) {
if ((mask >> j & 1) == 1 && a[i] % (1 << j) == 0) {
a[i] += 1 << (j - 1);
}
}
}

for (int i = 0; i < n; i++) {
io.print(a[i] + " ");
}
io.println();
}

Smilo and Monsters

比赛时我是排序 + 相向双指针模拟的,先干前面的怪物,如果计数和最后一个的怪物群数量相等,则使用终极攻击,比较麻烦的是双指针到达同一个位置时,需要特判一些情况。然后下面的解法,很简洁啊。似乎总是可以使用普通攻击干掉怪物总数的一半向上取整,并且使用终极攻击干掉总数的一半向下取整。然后排序数组并倒序遍历,使得一次终极攻击干掉尽可能多的怪物,这样就得到最少攻击次数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public static void solve() {
int n = io.nextInt();
int[] a = new int[n];
long sum = 0L;
for (int i = 0; i < n; i++) {
a[i] = io.nextInt();
sum += a[i];
}
Arrays.sort(a);
long ans = (sum + 1) / 2;
sum /= 2;
for (int i = n - 1; i >= 0 && sum > 0; i--) {
sum -= a[i];
ans++;
}
io.println(ans);
}

Suspicious logarithms

\(f(x)\) 表示 \(x\) 的二进制表示中最高位的 \(1\) 所在的位数 \(y\),而 \(g(x)\) 表示满足 \(y^{z}<= x\) 条件的最大的 \(z\)。可以发现如果 \(y=2,x=10^{18}\),则 \(z=59\)。我们可以枚举所有 \(y\in[2,59]\),对于特定的 \(y\),枚举不同的 \(z\) 覆盖的区间范围。得到各个区间范围内所有数的 \(z\) 值,我们就可以在 \(O(\log{(r-l+1)})\) 的时间复杂度内执行查询。为了避免乘法溢出,在进行比较时需要使用除法。其他人代码有直接使用 \(\log\) 的,也比较简单啊,我还以为很麻烦,结果溢出没想到换除法。当然也可以维护前缀和,然后二分区间位置来进行查询。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
private static final int MOD = (int) 1e9 + 7;
private static final List<long[]>[] list = new List[60];

static {
Arrays.setAll(list, k -> new ArrayList<>());

for (int f = 2; f < 60; f++) {
long l = 1L << f, r = (1L << f + 1) - 1;

long k = f, g = 1;
while (k <= l / f) {
k *= f;
g++;
}

for (; l <= r; l = k + 1, g++) {
k = k <= r / f ? k * f - 1 : r;
list[f].add(new long[]{l, k, g});
}
}
}

public static void solve() {
long ans = 0L;
long l = io.nextLong(), r = io.nextLong();
int i = 63 - Long.numberOfLeadingZeros(l);
int j = 63 - Long.numberOfLeadingZeros(r);
for (; i <= j; i++) {
for (long[] t : list[i]) {
ans = (ans + (Math.max(0, Math.min(t[1], r) - Math.max(t[0], l) + 1)) * t[2]) % MOD;
}
}
io.println(ans);
}

A Growing Tree

每个节点的编号是添加该节点时树的大小,因为修改操作不会影响还未添加到树上的节点,所以我们对每个修改操作添加一个编号(时间),表示修改所影响的范围。我们可以使用单点修改、区间查询的树状数组维护修改操作的编号,然后按照 DFS 序遍历树,每当遍历到一个节点,使用树状数组进行单点修改,因为遍历是 DFS 序,所以当前节点的祖先节点已经进行过修改操作,那么当前节点的答案就是所有大于等于该节点编号的修改操作之和。

那么有没有可能该答案会包含其他满足编号大于当前节点的非祖先节点的修改操作呢,不会包含,因为遍历是 DFS 序,DFS 返回时会取消对节点的修改操作,所以每当遍历到一个节点,修改操作只会包含其祖先节点的修改操作。特别注意,数组开 \(q+2\) 的大小,因为初始时有一个根节点,所以节点数量最多为 \(q+1\),然后编号从 \(1\) 开始。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public static void solve() {
int q = io.nextInt(), sz = 1;
List<int[]>[] queries = new List[q + 2];
Arrays.setAll(queries, k -> new ArrayList<>());
List<Integer>[] g = new List[q + 2];
Arrays.setAll(g, k -> new ArrayList<>());
for (int i = 0; i < q; i++) {
int t = io.nextInt();
if (t == 1) {
int v = io.nextInt();
g[v].add(++sz);
} else {
int v = io.nextInt(), x = io.nextInt();
queries[v].add(new int[]{sz, x});
}
}
var bit = new BIT(sz);
long[] ans = new long[sz + 1];
dfs(1, sz, g, queries, bit, ans);
for (int i = 1; i <= sz; i++) {
io.print(ans[i] + " ");
}
io.println();
}

private static void dfs(int x, int sz, List<Integer>[] g, List<int[]>[] queries, BIT bit, long[] ans) {
for (int[] q : queries[x]) {
bit.add(q[0], q[1]);
}
ans[x] = bit.get(x, sz);
for (int y : g[x]) {
dfs(y, sz, g, queries, bit, ans);
}
for (int[] q : queries[x]) {
bit.add(q[0], -q[1]);
}
}

第 369 场力扣周赛

找出数组中的 K-or 值

模拟。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Solution {
public int findKOr(int[] nums, int k) {
int ans = 0;
for (int i = 0; i < 31; i++) {
int cnt = 0;
for (int x : nums) {
cnt += x >> i & 1;
}
if (cnt >= k) {
ans |= 1 << i;
}
}
return ans;
}
}

数组的最小相等和

分类讨论。所有 \(0\) 都必须被替换为正整数,那么首先将所有 \(0\) 替换为 \(1\)。如果两个数组中都有 \(0\),则此时得到的最大的数组和就应该是答案,因为较小的一方总是可以使用更大的正整数替换 \(0\),使得两个数组的元素和相等。如果某个数组的和不等于最大和,并且数组中不包含 \(0\),那么就无法使两个数组的元素和相等。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution {
public long minSum(int[] nums1, int[] nums2) {
int cnt1 = 0, cnt2 = 0;
long sum1 = 0L, sum2 = 0L;
for (int x : nums1) {
if (x == 0) cnt1++;
sum1 += x;
}
for (int x : nums2) {
if (x == 0) cnt2++;
sum2 += x;
}
long max = Math.max(sum1 + cnt1, sum2 + cnt2);
if (sum1 != max && cnt1 == 0) return -1;
if (sum2 != max && cnt2 == 0) return -1;
return max;
}
}

使数组变美的最小增量运算数

按照灵神的题解,每个位置的状态就是它右边有多少个数小于 \(k\),加上这个维度就可以做记忆化搜索,然后转递推真的很妙。题解区还有其他的状态定义方式,可以看看。代码就贴灵神的。(这题感觉很不错,没有见过的类型)

1
2
3
4
5
6
7
8
9
10
11
12
class Solution {
public long minIncrementOperations(int[] nums, int k) {
long f0 = 0, f1 = 0, f2 = 0;
for (int x : nums) {
long inc = f0 + Math.max(k - x, 0);
f0 = Math.min(inc, f1);
f1 = Math.min(inc, f2);
f2 = inc;
}
return f0;
}
}

收集所有金币可获得的最大积分

当我们遍历到某个节点时,它的状态就是需要除以多少次 \(2\),由数据范围可知每个节点最多有 \(15\) 个状态。我们可以从子问题的最优解推出原问题的最优解,并且子问题可以独立求解,符合最优子结构;如果当前节点处于某个状态,它可能是由不同的路径转移得到的,即存在重叠子问题。所以我们可以使用树型 DP 求解该问题,列出如下状态转移方程:

$$ dp[x][i]=\max(\sum_{y}{dp[y][i]}+(coins[x]>>i)-k,\sum_{y}{dp[y][i + 1]}+(coins[x]>>(i+1))) $$

其中 \(dp[i][j]\) 表示到达节点 \(i\),需要除以 \(2^{j}\),该状态下以节点 \(i\) 为根的子树能够得到的最大积分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class Solution {
public int maximumPoints(int[][] edges, int[] coins, int k) {
int n = coins.length;
List<Integer>[] g = new List[n];
Arrays.setAll(g, t -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1];
g[u].add(v);
g[v].add(u);
}
return dfs(0, -1, g, coins, k)[0];
}

private int[] dfs(int x, int fa, List<Integer>[] g, int[] coins, int k) {
int[] sum1 = new int[15];
int[] sum2 = new int[15];

for (int y : g[x]) {
if (y == fa) continue;
int[] t = dfs(y, x, g, coins, k);
for (int i = 0; i < 14; i++) {
sum1[i] += t[i];
sum2[i] += t[i + 1];
}
}

for (int i = 0; i < 14; i++) {
sum1[i] = Math.max(sum1[i] + (coins[x] >> i) - k, sum2[i] + (coins[x] >> (i + 1)));
}
return sum1;
}
}