-学习笔记-虚树- 虚树学习笔记

今天是个好日子!

最近学了非常酷炫的虚树,感觉这个算法挺实用的。

虚树的做法就是把有用的点及其它们的LCA拎出来,无关的点给丢掉,这样就大大优化了树形dp的复杂度。

设有效节点为 $k$ 个,则建虚树的时间复杂度为 $O(k \log n)$,在虚树上进行树形dp的时间复杂度为 $O(k)​$。

下面的代码展示了如何建虚树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void build(int x) {
if (!top) {
st[++top] = x;
return;
}
int lca = LCA(st[top], x);
if (lca == st[top]) return;
while (top && id[st[top - 1]] >= id[lca]) {
v[st[top - 1]].push_back(st[top]);
top--;
}
if (lca != st[top]) {
v[lca].push_back(st[top]);
st[top] = lca;
}
st[++top] = x;
}

上面的代码展示了如何用栈来维护这些点。

然后就是一道例题:

[SDOI2011]消耗战

这题数据范围较大,但是k较小,$O(nm)$ 的DP肯定过不去,所以考虑使用虚树优化,将复杂度降低为 $O(k)$。

献上AC代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define MAXN 250010
#define int long long
struct Edge {
int v, nx, w;
}e[MAXN << 1];
int head[MAXN], ecnt, n, m, x, y, z, id[MAXN], dep[MAXN], son[MAXN], k;
int si[MAXN], st[MAXN], tp, top[MAXN], fa[MAXN], tim, mi[MAXN], a[MAXN];
std::vector <int> v[MAXN];
bool cmp(int a, int b) {
return id[a] < id[b];
}
void add(int f, int t, int w) {
e[++ecnt] = (Edge) {t, head[f], w};
head[f] = ecnt;
}
void dfs1(int u, int f, int d) {
dep[u] = d;
si[u] = 1;
fa[u] = f;
for (int i = head[u]; i; i = e[i].nx) {
int to = e[i].v;
if (to == f) continue;
mi[to] = std::min(mi[u], e[i].w);
dfs1(to, u, d + 1);
si[u] += si[to];
if (si[to] > si[son[u]]) son[u] = to;
}
}
void dfs2(int u, int topf) {
top[u] = topf;
id[u] = ++tim;
if (!son[u]) return;
dfs2(son[u], topf);
for (int i = head[u]; i; i = e[i].nx) {
int to = e[i].v;
if (to == fa[u] || to == son[u]) continue;
dfs2(to, to);
}
}
int LCA(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
x = fa[top[x]];
}
if (dep[x] > dep[y]) std::swap(x, y);
return x;
}
void build(int x) {
if (tp == 1) {
st[++tp] = x;
return;
}
int lca = LCA(st[tp], x);
if (lca == st[tp]) return;
while (tp > 1 && id[st[tp - 1]] >= id[lca]) {
v[st[tp - 1]].push_back(st[tp]);
tp--;
// printf("%d -> %d\n", st[tp - 1], st[tp]);
}
if (lca != st[tp]) {
v[lca].push_back(st[tp]);
st[tp] = lca;
}
st[++tp] = x;
}
int dp(int u) {
if (v[u].size() == 0) return mi[u];
int ans = 0;
for (int i = 0; i < (int)v[u].size(); i++) {
int to = v[u][i];
// printf("%d\n", to);
ans += dp(to);
}
v[u].clear();
return std::min(ans, mi[u]);
}
signed main() {
scanf("%lld", &n);
for (int i = 1; i < n; i++) {
scanf("%lld%lld%lld", &x, &y, &z);
add(x, y, z);
add(y, x, z);
}
memset(mi, 0x7f, sizeof(mi));
dfs1(1, -1, 1);
dfs2(1, 1);
scanf("%lld", &m);
while (m--) {
scanf("%lld", &k);
for (int i = 1; i <= k; i++) {
scanf("%lld", &a[i]);
}
std::sort(a + 1, a + 1 + k, cmp);
tp = 1;
st[tp] = 1;
for (int i = 1; i <= k; i++) {
build(a[i]);
}
// puts("");
// puts("----");
while (tp > 0) v[st[tp - 1]].push_back(st[tp]), tp--;
printf("%lld\n", dp(1));
// puts("----");
}
}

原谅我写这么简短的学习笔记

文章作者: RiverFun
文章链接: https://stevebraveman.github.io/blog/2019/08/17/95/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 RiverFun

评论
目录