Project #3 - Query Execution

项目准备

项目地址:Project #3 - Query Execution

准备工作:阅读 Chapter 15 16 22,学习 Lecture #10 #11 #12 #13 #14,以及阅读课堂笔记。

项目结构

通过查看 sqllogictest.cpp,可以知道 SQL 语句的整个执行流程。首先调用 SQLLogicTestParser::Parse 将测试文件解析为多个测试记录,然后根据记录的类型分别处理。目前我们主要关注查询语句,只需查看 BustubInstance::ExecuteSqlTxn 函数的代码。如项目介绍描述的那样,代码分别执行 Binder,Planner,Optimize,ExecutionEngine。然后,本来想详细分析一下整个流程,但是由于时间原因,以及项目确实比较复杂,所以暂时搁置。

Task #1 - Access Method Executors

实现

① 遇到第一个问题,如何在 SeqScanExecutor 中遍历表,可以发现 exec_ctx 成员所属的类 ExecutorContext 中有一个 GetCatalog 方法,只要拿到 Catalog 就可以根据 plan_ 中的信息拿到 TableHeap 的迭代器 TableIterator。然后第二个问题就是如何存储迭代器,TableIterator 是不可复制的,我们可以使用 unique_ptr 来存储迭代器,并使用 make_unique 初始化。(注意,不能在构造函数初始化,一定要在 Init 函数中初始化,不然多次全表扫描会出问题!)

② 实现 Insert 时报错 “The executor will produce a single tuple of integer type as the output, indicating how many rows have been inserted into the table”,并且可以看到 Next 函数的注释中表示 “tuple The integer tuple indicating the number of rows inserted into the table”。说实话有点难以理解,我一开始以为每次调用 Next 会像迭代器模式一样,只执行一次插入,但是这样实现就会报上面的错误。然后通过查看 Discord 的讨论,发现是一次性插入所有记录,因为只要返回 true 就会打印插入的行数,返回 false 就不会打印。当插入零行时,还必须打印一个零,这说明,Next 必定要先返回 true,再返回 false。并且在构造 tuple 时需要使用 BIGINT 类型,不然会报其他错误(明明注释说的是 INTEGER 额)。

③ 在 Insert 的同时需要更新索引,一开始我是直接用普通的 tuple 作为 InsertEntry 的参数,结果在测试 p3.05-index-scan.slt 时报 stack buffer overflow 错误。通过 Debug 发现,在 InsertEntry 时会调用 GenericKey 类的 SetFromKey 函数,该函数会将 tuple 的数据拷贝到该类的 data_ 成员中,作为索引的 key 使用。所以传入的 tuple 必须只包含 key,那么如何确定 tuple 中的哪个数据是 key 呢。可以发现 Tuple 类中有 KeyFromTuple 函数,它的会生成只包含 keytuple,因为需要的索引的 key,那么该函数必定需要传入和索引相关的模式,以及 key 所在列的下标,这些信息可以在 IndexInfo 中找到。(之前我有点迷糊,当成 MySQL 默认使用主键索引了,BusTub 使用的是 TableHeap,也就是说表默认是没有索引的)

④ 实现时不要使用 GetTableOid 函数,因为线上测试的函数名是 TableOid,可能是因为我 fork 的版本太新了,仓库的代码和测试代码不一样,所以只能直接使用 table_oid_ 成员。

⑤ 实现 update 时要注意,在创建新 tuple 时,使用的是 child_executor_->GetOutputSchema(),而不是 GetOutputSchema()

⑥ 实现 index_scan 时,会使用到 b_plus_tree_index.h 中定义的别名,如 BPlusTreeIndexIteratorForTwoIntegerColumn

⑧ 在 IndexScan 的提示中有这么一句话,“do not emit tuples that are deleted”,但是当从表中删除 tuple 时,也会从索引中删除对应的 key,所以应该不会遍历到已经删除的 key 才对,也就是说此时应该不用特判 TupleMeta 中的 is_deleted_ 成员。

⑨ 测试 p3.06-empty-table.slt 时,遇到 B+Tree 迭代器实现问题。当 B+Tree 的为 empty 时,获取迭代器我原来是抛出异常,现在改为返回一个默认构造的迭代器。

补充

① 当没有显示声明复制/移动构造函数或复制/移动运算符,以及析构函数时,编译器才会隐式生成这些函数(其他更复杂的情况可以查看 cppreference.com)。

② 创建 TupleMeta 时,会将 insertion_txndeletion_txn_ 都初始化为 INVALID_TXN_ID,提示表示这些成员会在以后切换到 MVCC 存储时使用,有点遗憾没能体验一下。

vectorreserve 只会影响 capacity 的大小,而不会影响 size讨论在此

④ 重载前置和后置 ++ 的区别,前置 ++ 的重载声明为 operator++(),后置 ++ 的重载声明为 operator++(int)

⑤ 为什么应该将移动构造声明为 noexcept,可以阅读 Friendly reminder to mark your move constructors noexcept

Task #2 - Aggregation & Join Executors

实现

① 一开始实现真摸不着头脑,AggregationPlanNode 里面怎么这么多东西。group_bys 是指 GROUP BY 时对列的求值表达式,aggregates 是指使用聚合函数时对列的求值表达式,agg_types 是指聚合函数的类型。例如:GROUP BY DAY(col)MIN(col1 + col2)。我们使用 InsertCombine 函数向哈希表插入值,参数可以使用 MakeAggregateKeyMakeAggregateValue 函数获得。

② 根据项目介绍,AggregationExecutor::Next 返回的 tuple 应该包含 keyvalue(我没看到,找错好难)。特别需要注意,当哈希表为空时,应该返回什么:如果是对某列进行 GROUP BY 操作,那么就返回 false,因为有个测试用例有注释 no groups, no output;否则,返回 true,并且 tuple 存储聚合函数的默认值。(可以通过判断 key 模式的列数是否为零,或者 value 模式的列数是否等于 plan_ 输出模式的列数,来判断当前是否对某列进行 GROUP BY 操作)

③ 实现 NestedLoopJoinExecutor:外层循环遍历左表,内层循环遍历右表,只有当右表遍历完,才会获取下一个左表中的元组。但是,因为每找到一个匹配就会返回,所以我们应该将左表的元组作为数据成员,并且添加一个标志表示右表是否遍历完。每当右表遍历完成,都需要重置标志,获取左表中的下一个元组,并且重新 Init 右表。我们调用 EvaluateJoin 判断元组是否匹配,如果匹配,就将两个元组合并为一个元组。特别注意,如果当前是左连接,并且左元组没有匹配任何右元组,仍然需要返回一个为右元组填充 null 值的合并元组。比较迷惑的是怎么表示 null,我的想法是根据列类型获取对应的 null 值,但是找不到这样的函数,所以我就直接返回 BUSTUB_INT32_NULL。突然看到聚合执行器里用到 ValueFactory::GetNullValueByType 函数,太久没写项目给忘了。我还遇到一个 BUG,调试半天,发现我没有在 Init 函数中初始化 SeqScanExecutor 的迭代器,导致重复调用 Init 时不会重置迭代器。

④ 实现 HashJoin:根据提示我们可以参考 SimpleAggregationHashTable 的实现建立一个哈希表,我们创建一个 JoinKey 类作为键,然后创建一个 hash<bustub::JoinKey> 类,直接复制 aggregation_plan.h 中的代码改个名字就行(不然 C++ 真不熟,又要搞半天)。在哈希表中,将 vector<Tuple> 作为值以处理哈希冲突。搞定哈希的方式之后,我们可以像 aggregation_executor.h 一样添加两个辅助函数 MakeLeftJoinKeyMakeRightJoinKey。然后直接在 Init 中对左表建立哈希表,在 Next 中遍历右表,类似 NestedLoopJoinExecutor 的实现,只不过此时需要维护更多的数据成员。特别需要注意如何处理左连接,因为我们是将左表建为哈希表,那么在遍历完右表后,还需要处理没有任何匹配的左表中的元组。这可以在匹配时将元组的地址存储在 unordered_set 中,然后在遍历完右表后再遍历一次左表,并检查 unordered_set 来判断是否输出。(之前我是将元组的 RID 存储到集合中作为标识,但是这是错误的,因为左表可能是临时表,其中元组的 RID 是无效的内容;我们也可以为右表建立哈希表而不是左表,这样对于左连接来说,更好处理)

⑤ 实现 Optimizing NestedLoopJoin to HashJoin:非常的神奇,参考 nlj_as_index_join.cpp 瞎改,感觉代码是一坨,但是竟然没有任何错误,直接通过测试(激动半天)。具体实现的话,一开始我以为传入的参数就是 NestedLoopJoin 计划节点,但是似乎不是,所以我们需要遍历当前计划的子节点,递归的进行优化。之前比较令我迷惑的一点,怎么判断表达式是否是某个类型,我查找很久 API 都没有找到类似的函数,然后想到 Project #0 中好像是直接做 dynamic_cast 转换,如果返回值为 nullptr 就表示类型不匹配,查看 nlj_as_index_join.cpp 发现果然是这样。搞定表达式类型判断之后,就可以根据 ColumnValueExpression::GetTupleIdx 值来交换左右表达式,并返回转换后的节点。

Task #3 - Sort + Limit Executors and Top-N Optimization

Easy!只有两点需要注意:一个是每次调用 Init 时都要初始化所有数据成员,不然下次调用会包含上次调用的数据;第二个是 C++ 的 priority_queue 默认是大顶堆,并且比较器和 Java 中的用法完全相反。

Optional Leaderboard Tasks

① 初次提交。

② 之后优化。

Rank Submission Name Q1 Q2 Q3 Time
123 ALEX 740 30000 4839 4754

测试结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#!/bin/bash
make sqllogictest -j$(nproc)

./bin/bustub-sqllogictest ../test/sql/p3.00-primer.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.01-seqscan.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.02-insert.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.04-delete.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.03-update.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.05-index-scan.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.06-empty-table.slt --verbose

./bin/bustub-sqllogictest ../test/sql/p3.07-simple-agg.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.08-group-agg-1.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.09-group-agg-2.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.10-simple-join.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.11-multi-way-join.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.12-repeat-execute.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.14-hash-join.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.15-multi-way-hash-join.slt --verbose

./bin/bustub-sqllogictest ../test/sql/p3.16-sort-limit.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.17-topn.slt --verbose

./bin/bustub-sqllogictest ../test/sql/p3.18-integration-1.slt --verbose
./bin/bustub-sqllogictest ../test/sql/p3.19-integration-2.slt --verbose

make format
make check-lint
make check-clang-tidy-p3
make submit-p3

项目小结

项目难度主要在项目理解上,常常是不理解某些变量的实际含义,或者知道该怎么做,却找不到对应的 API,或者对返回值理解有错误,而函数文档也不清晰。最后,看到实现的代码能够执行各种 SQL 语句,感觉还是很不错的。

图论

本文内容参考《算法》,《算法导论》,OI Wiki

拓扑排序

例题

实现

  • 时间复杂度为 \(O(n+m)\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
private static int[] topologicalSort(int n, int[][] edges) {
int[] in = new int[n];
List<Integer>[] g = new List[n];
Arrays.setAll(g, k -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1];
g[u].add(v);
in[v]++;
}

Queue<Integer> q = new ArrayDeque<>();
for (int i = 0; i < n; i++) {
if (in[i] == 0) {
q.offer(i);
}
}

int idx = 0;
int[] res = new int[n];
while (!q.isEmpty()) {
int x = q.poll();
res[idx++] = x;
for (int y : g[x]) {
if (--in[y] == 0) {
q.offer(y);
}
}
}

// 拓扑排序不存在
assert idx == n;

return res;
}

最小生成树

例题

Prim

实现一:朴素版本

  • 时间复杂度为 \(O(n^{2})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
private static int prim(int n, int[][] edges) {
int[][] g = new int[n][n];
for (int i = 0; i < n; i++) {
Arrays.fill(g[i], Integer.MAX_VALUE);
g[i][i] = 0;
}
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
if (g[u][v] > w) {
g[u][v] = g[v][u] = w;
}
}

int[] d = new int[n];
Arrays.fill(d, Integer.MAX_VALUE);
boolean[] vis = new boolean[n];

int res = 0;
d[0] = 0;
for (int i = 0; i < n; i++) {
int t = -1;
for (int j = 0; j < n; j++) {
if (!vis[j] && (t == -1 || d[t] > d[j])) {
t = j;
}
}

// 不是连通图,最小生成树不存在
assert d[t] != Integer.MAX_VALUE;

vis[t] = true;
res += d[t];

for (int j = 0; j < n; j++) {
d[j] = Math.min(d[j], g[t][j]);
}
}

return res;
}

实现二:优先队列优化

  • 时间复杂度为 \(O(m\log{m})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
private static int prim(int n, int[][] edges) {
List<int[]>[] g = new List[n];
Arrays.setAll(g, k -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
g[u].add(new int[]{v, w});
g[v].add(new int[]{u, w});
}

int[] d = new int[n];
Arrays.fill(d, Integer.MAX_VALUE);
boolean[] vis = new boolean[n];
Queue<int[]> q = new PriorityQueue<>((a, b) -> a[1] - b[1]);

int res = 0, cnt = 0;
d[0] = 0;
q.offer(new int[]{0, 0});
while (!q.isEmpty()) {
int u = q.poll()[0];
if (vis[u]) continue;
vis[u] = true;
res += d[u];
if (++cnt == n) break;
for (int[] t : g[u]) {
int v = t[0], w = t[1];
if (!vis[v] && d[v] > w) {
d[v] = w;
q.offer(new int[]{v, d[v]});
}
}
}

// 不是连通图,最小生成树不存在
assert cnt == n;

return res;
}

Kruskal

  • 时间复杂度为 \(O(m\log{m})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
private static int kruskal(int n, int[][] edges) {
Arrays.sort(edges, (a, b) -> a[2] - b[2]);

int cnt = 1, res = 0;
UnionFind uf = new UnionFind(n);
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
if (uf.connected(u, v)) continue;
uf.union(u, v);
res += w;
if (++cnt == n) break;
}

// 不是连通图,最小生成树不存在
assert cnt == n;

return res;
}

最短路

例题

Dijkstra

  • 使用场景:解决边权非负的单源最短路问题。

实现一:朴素版本

  • 时间复杂度为 \(O(n^{2})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
private static final int INF = (int) 1e9;

private static int dijkstra(int n, int[][] edges) {
int[][] g = new int[n][n];
for (int i = 0; i < n; i++) {
Arrays.fill(g[i], INF);
g[i][i] = 0;
}
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
g[u][v] = Math.min(g[u][v], w);
}

int[] d = new int[n];
Arrays.fill(d, INF);
boolean[] vis = new boolean[n];

d[0] = 0;
while (true) {
int t = -1;
for (int i = 0; i < n; i++) {
if (!vis[i] && (t == -1 || d[t] > d[i])) {
t = i;
}
}

if (t == n - 1 || d[t] == INF) {
break;
}
vis[t] = true;

for (int i = 0; i < n; i++) {
d[i] = Math.min(d[i], d[t] + g[t][i]);
}
}

return d[n - 1] == INF ? -1 : d[n - 1];
}

实现二:优先队列优化

  • 时间复杂度为 \(O(m\log{m})\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
private static final int INF = (int) 1e9;

private static int dijkstra(int n, int[][] edges) {
List<int[]>[] g = new List[n];
Arrays.setAll(g, k -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
g[u].add(new int[]{v, w});
}

int[] d = new int[n];
Arrays.fill(d, INF);
boolean[] vis = new boolean[n];
Queue<int[]> q = new PriorityQueue<>((a, b) -> a[1] - b[1]);

d[0] = 0;
q.offer(new int[]{0, 0});
while (!q.isEmpty()) {
int u = q.poll()[0];
if (u == n - 1) break;
if (vis[u]) continue;
vis[u] = true;
for (int[] t : g[u]) {
int v = t[0], w = t[1];
if (d[v] > d[u] + w) {
d[v] = d[u] + w;
q.offer(new int[]{v, d[v]});
}
}
}

return d[n - 1] == INF ? -1 : d[n - 1];
}

Bellman-Ford

  • 时间复杂度为 \(O(nm)\)。
  • 使用场景:解决任意边权的单源最短路问题;判断是否存在负环;解决有边数限制的单源最短路问题。

实现一:朴素版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
private static final int INF = (int) 1e9;

private static int bellmanFord(int n, int[][] edges) {
int[] d = new int[n];
Arrays.fill(d, INF);

d[0] = 0;
for (int i = 0; i < n; i++) {
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
d[v] = Math.min(d[v], d[u] + w);
}
}

// d[n - 1] == INF 时,最短路不存在
return d[n - 1];
}

实现二:队列优化(不能存在负环)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
private static final int INF = (int) 1e9;

private static int spfa(int n, int[][] edges) {
List<int[]>[] g = new List[n];
Arrays.setAll(g, k -> new ArrayList<>());
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
g[u].add(new int[]{v, w});
}

int[] d = new int[n];
Arrays.fill(d, INF);
Queue<Integer> q = new ArrayDeque<>();
boolean[] on = new boolean[n];

d[0] = 0;
q.offer(0);
on[0] = true;
while (!q.isEmpty()) {
int u = q.poll();
on[u] = false;
for (int[] t : g[u]) {
int v = t[0], w = t[1];
if (d[v] > d[u] + w) {
d[v] = d[u] + w;
if (!on[v]) {
q.offer(v);
on[v] = true;
}
}
}
}

// d[n - 1] == INF 时,最短路不存在
return d[n - 1];
}

Floyd-Warshall

  • 时间复杂度为 \(O(n^{3})\)。
  • 使用场景:解决任意边权的多源最短路问题。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
private static final int INF = (int) 1e9;

private static int[][] floyd(int n, int[][] edges) {
int[][] dp = new int[n][n];
for (int i = 0; i < n; i++) {
Arrays.fill(dp[i], INF);
dp[i][i] = 0;
}
for (var e : edges) {
int u = e[0], v = e[1], w = e[2];
dp[u][v] = Math.min(dp[u][v], w);
}

for (int k = 0; k < n; k++) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (dp[i][k] != INF && dp[k][j] != INF) {
dp[i][j] = Math.min(dp[i][j], dp[i][k] + dp[k][j]);
}
}
}
}
return dp;
}

最近公共祖先

例题

倍增

  • 预处理时间复杂度为 \(O(n\log{n})\),查询时间复杂度为 \(O(\log{n})\)。
  • 原理:\(f[i][j]\) 表示节点 \(j\) 的第 \(2^{i}\) 个祖先,当利用倍增得到 \(f\) 时,对于任意两个节点 \(x,y\),先将较深的节点向上跳到相同深度,然后两个节点贪心的向上跳到 \(\operatorname{lca}\) 下方距离它最近的节点,最后得到的节点就是 \(\operatorname{lca}\) 的直接子节点。(在进行倍增时,根节点的父节点可以是任何值,因为该值不会影响算法的正确性)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
private static void dfs(int x, int fa, List<Integer>[] g, int[][] f, int[] d) {
f[0][x] = fa;
for (int i = 1; 1 << i <= d[x]; i++) {
f[i][x] = f[i - 1][f[i - 1][x]];
}
for (int y : g[x]) {
if (y != fa) {
d[y] = d[x] + 1;
dfs(y, x, g, f, d);
}
}
}

private static int lca(int x, int y, int[][] f, int[] d) {
if (d[x] > d[y]) {
int t = x; x = y; y = t;
}

int diff = d[y] - d[x];
for (int i = 0; i < 31; i++) {
if ((diff >> i & 1) == 1) {
y = f[i][y];
}
}

if (x != y) {
for (int i = 30; i >= 0; i--) {
if (f[i][x] != f[i][y]) {
x = f[i][x];
y = f[i][y];
}
}
x = f[0][x];
}
return x;
}

Tarjan

  • 离线查询算法,时间复杂度为 \(O((n+m)\log{n})\)。更精确的复杂度分析可以使用反阿克曼函数。
  • 原理:每当处理完一个子树,就将该子树的根节点和其父节点合并,特别注意合并的方向是 \(f[y]=x\)。然后我们会遍历包含当前节点 \(x\) 的查询,如果另一个节点 \(y\) 访问过,则 \(\operatorname{lca}(x,y)=\operatorname{find}(y)\)。至于为什么是这样,可以通过分类讨论得到。注意 \(q\) 需要像无向图一样,为单个查询存储双向边。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
private static void tarjan(int x, List<Integer>[] g, boolean[] vis, UnionFind uf, List<int[]>[] q, int[] ans) {
vis[x] = true;
for (int y : g[x]) {
if (!vis[y]) {
tarjan(y, g, vis, uf, q, ans);
uf.union(x, y); // 注意 f[y] = x
}
}

for (int[] t : q[x]) {
int y = t[0], i = t[1];
if (vis[y]) {
ans[i] = uf.find(y);
}
}
}

树链剖分

  • 预处理时间复杂度为 \(O(n)\),查询时间复杂度为 \(O(\log{n})\)。
  • 原理:将树划分为若干重链,树中的每条路径不会包含超过 \(\log{n}\) 条不同的重链,所以查询的时间复杂度为 \(O(\log{n})\)。第一次 DFS 得到每个节点的父节点,深度,以及根据子树大小得到每个节点的重子节点。第二次 DFS 通过优先遍历重子节点,再遍历轻子节点,从而得到每个节点所在重链的头节点。然后就可以进行查询,通过比较 \(x,y\) 所在重链的头节点,来向上跳跃,最终得到 \(\operatorname{lca}\)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
private static void dfs1(int x, int fa, List<Integer>[] g, int[] f, int[] d, int[] s, int[] h) {
f[x] = fa;
s[x] = 1; h[x] = -1;
for (int y : g[x]) {
if (y != fa) {
d[y] = d[x] + 1;
dfs1(y, x, g, f, d, s, h);
s[x] += s[y];
if (h[x] == -1 || s[h[x]] < s[y]) {
h[x] = y;
}
}
}
}

private static void dfs2(int x, int head, List<Integer>[] g, int[] f, int[] h, int[] t) {
t[x] = head;
if (h[x] == -1) {
return;
}
dfs2(h[x], head, g, f, h, t);
for (int y : g[x]) {
if (y != f[x] && y != h[x]) {
dfs2(y, y, g, f, h, t);
}
}
}

private static int lca(int x, int y, int[] f, int[] d, int[] t) {
while (t[x] != t[y]) {
if (d[t[x]] > d[t[y]]) {
x = f[t[x]];
} else {
y = f[t[y]];
}
}
return d[x] < d[y] ? x : y;
}

强连通分量

例题

Tarjan

  • 时间复杂度为 \(O(n+m)\)。
  • 原理:\(dfn[x]\) 表示节点 \(x\) 的 DFS 编号;\(low[x]\) 表示节点 \(x\) 能够到达的节点的最小的 DFS 编号。我们将图看作一棵树,并定义四种边,那么强连通分量的根节点就是该分量中第一个被遍历到的节点,满足 \(dfn[x]=low[x]\),所以,过程很复杂,难以描述,直接看 wiki 吧。(注意使用的时候,将 \(dfn\) 初始化为 \(-1\),并且对所有节点调用该算法前,需要判断 \(dfn=-1\) 是否成立)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
private static int dfnCnt, sccCnt;

private static void tarjan(int x, List<Integer>[] g, int[] dfn, int[] low, Deque<Integer> stk, boolean[] on, int[] scc, int[] size) {
dfn[x] = low[x] = dfnCnt++;
stk.push(x);
on[x] = true;

for (int y : g[x]) {
if (dfn[y] == -1) {
tarjan(y, g, dfn, low, stk, on, scc, size);
low[x] = Math.min(low[x], low[y]);
} else if (on[y]) {
low[x] = Math.min(low[x], dfn[y]);
}
}

if (dfn[x] == low[x]) {
for (int y = -1; y != x; ) {
y = stk.pop();
on[y] = false;
scc[y] = sccCnt;
size[sccCnt]++;
}
sccCnt++;
}
}

Codeforces Round 907 (Div. 2)

Sorting with Twos

因为每次只能操作区间 \([1,2^{m}]\),所以 \([2^{m}+1,2^{m+1}]\) 内的所有数是同时进行操作的,它们需要满足非递减的性质,最后不要忘记结尾不能操作的数也需要满足条件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public static void solve() {
int n = io.nextInt();
int[] a = new int[n + 1];
for (int i = 1; i <= n; i++) {
a[i] = io.nextInt();
}

for (int i = 1; 1 << i <= n; i++) {
int j = Math.min(1 << (i + 1), n);
for (int k = (1 << i) + 1; k < j; k++) {
if (a[k] > a[k + 1]) {
io.println("NO");
return;
}
}
}
io.println("YES");
}

Deja Vu

如果一个数能够被 \(2^{i}\) 整除,那么操作之后,它只能被所有小于等于 \(2^{i-1}\) 的二的幂整除。所以预处理所有修改,只保留满足递减顺序的修改,然后模拟即可。或者也可以直接修改,不用预处理,在修改之前判断一下是否比上次小就行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public static void solve() {
int n = io.nextInt(), q = io.nextInt();
int[] a = new int[n];
for (int i = 0; i < n; i++) {
a[i] = io.nextInt();
}

int mask = 0;
for (int i = 0; i < q; i++) {
int x = io.nextInt();
if (mask == 0 || (1 << x) <= (mask & -mask)) {
mask |= 1 << x;
}
}

for (int i = 0; i < n; i++) {
for (int j = 30; j >= 0; j--) {
if ((mask >> j & 1) == 1 && a[i] % (1 << j) == 0) {
a[i] += 1 << (j - 1);
}
}
}

for (int i = 0; i < n; i++) {
io.print(a[i] + " ");
}
io.println();
}

Smilo and Monsters

比赛时我是排序 + 相向双指针模拟的,先干前面的怪物,如果计数和最后一个的怪物群数量相等,则使用终极攻击,比较麻烦的是双指针到达同一个位置时,需要特判一些情况。然后下面的解法,很简洁啊。似乎总是可以使用普通攻击干掉怪物总数的一半向上取整,并且使用终极攻击干掉总数的一半向下取整。然后排序数组并倒序遍历,使得一次终极攻击干掉尽可能多的怪物,这样就得到最少攻击次数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public static void solve() {
int n = io.nextInt();
int[] a = new int[n];
long sum = 0L;
for (int i = 0; i < n; i++) {
a[i] = io.nextInt();
sum += a[i];
}
Arrays.sort(a);
long ans = (sum + 1) / 2;
sum /= 2;
for (int i = n - 1; i >= 0 && sum > 0; i--) {
sum -= a[i];
ans++;
}
io.println(ans);
}

Suspicious logarithms

\(f(x)\) 表示 \(x\) 的二进制表示中最高位的 \(1\) 所在的位数 \(y\),而 \(g(x)\) 表示满足 \(y^{z}<= x\) 条件的最大的 \(z\)。可以发现如果 \(y=2,x=10^{18}\),则 \(z=59\)。我们可以枚举所有 \(y\in[2,59]\),对于特定的 \(y\),枚举不同的 \(z\) 覆盖的区间范围。得到各个区间范围内所有数的 \(z\) 值,我们就可以在 \(O(\log{(r-l+1)})\) 的时间复杂度内执行查询。为了避免乘法溢出,在进行比较时需要使用除法。其他人代码有直接使用 \(\log\) 的,也比较简单啊,我还以为很麻烦,结果溢出没想到换除法。当然也可以维护前缀和,然后二分区间位置来进行查询。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
private static final int MOD = (int) 1e9 + 7;
private static final List<long[]>[] list = new List[60];

static {
Arrays.setAll(list, k -> new ArrayList<>());

for (int f = 2; f < 60; f++) {
long l = 1L << f, r = (1L << f + 1) - 1;

long k = f, g = 1;
while (k <= l / f) {
k *= f;
g++;
}

for (; l <= r; l = k + 1, g++) {
k = k <= r / f ? k * f - 1 : r;
list[f].add(new long[]{l, k, g});
}
}
}

public static void solve() {
long ans = 0L;
long l = io.nextLong(), r = io.nextLong();
int i = 63 - Long.numberOfLeadingZeros(l);
int j = 63 - Long.numberOfLeadingZeros(r);
for (; i <= j; i++) {
for (long[] t : list[i]) {
ans = (ans + (Math.max(0, Math.min(t[1], r) - Math.max(t[0], l) + 1)) * t[2]) % MOD;
}
}
io.println(ans);
}

A Growing Tree

每个节点的编号是添加该节点时树的大小,因为修改操作不会影响还未添加到树上的节点,所以我们对每个修改操作添加一个编号(时间),表示修改所影响的范围。我们可以使用单点修改、区间查询的树状数组维护修改操作的编号,然后按照 DFS 序遍历树,每当遍历到一个节点,使用树状数组进行单点修改,因为遍历是 DFS 序,所以当前节点的祖先节点已经进行过修改操作,那么当前节点的答案就是所有大于等于该节点编号的修改操作之和。

那么有没有可能该答案会包含其他满足编号大于当前节点的非祖先节点的修改操作呢,不会包含,因为遍历是 DFS 序,DFS 返回时会取消对节点的修改操作,所以每当遍历到一个节点,修改操作只会包含其祖先节点的修改操作。特别注意,数组开 \(q+2\) 的大小,因为初始时有一个根节点,所以节点数量最多为 \(q+1\),然后编号从 \(1\) 开始。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
public static void solve() {
int q = io.nextInt(), sz = 1;
List<int[]>[] queries = new List[q + 2];
Arrays.setAll(queries, k -> new ArrayList<>());
List<Integer>[] g = new List[q + 2];
Arrays.setAll(g, k -> new ArrayList<>());
for (int i = 0; i < q; i++) {
int t = io.nextInt();
if (t == 1) {
int v = io.nextInt();
g[v].add(++sz);
} else {
int v = io.nextInt(), x = io.nextInt();
queries[v].add(new int[]{sz, x});
}
}
var bit = new BIT(sz);
long[] ans = new long[sz + 1];
dfs(1, sz, g, queries, bit, ans);
for (int i = 1; i <= sz; i++) {
io.print(ans[i] + " ");
}
io.println();
}

private static void dfs(int x, int sz, List<Integer>[] g, List<int[]>[] queries, BIT bit, long[] ans) {
for (int[] q : queries[x]) {
bit.add(q[0], q[1]);
}
ans[x] = bit.get(x, sz);
for (int y : g[x]) {
dfs(y, sz, g, queries, bit, ans);
}
for (int[] q : queries[x]) {
bit.add(q[0], -q[1]);
}
}