# 树链剖分(重链剖分)

模板:洛谷 P3384 【模板】轻重链剖分/树链剖分

写在前面:强烈建议初学的同学如果不理解的话先手写一遍(像我一样),非常有助于理解

# 概念:

  • 重儿子: 一个节点所有儿子中最大的儿子

  • 轻儿子: 一个节点除重儿子之外的其他儿子

    特别地,叶子节点既没有重儿子也没有轻儿子

  • 重链: 重儿子连接形成的链叫重链

# 主要思想:

把一颗树拆成许多条链,把树上操作改为链上操作(区间操作),并利用线段树或树状数组等数据结构维护,以降低时间复杂度

每次进行操作时,先修改重链,再修改轻链

# 常见操作:

  • 将树从 xxyy 结点最短路径上所有节点的值都加上 zz

  • 求树从 xxyy 结点最短路径上所有节点的值之和

  • 将以 xx 为根节点的子树内所有节点值都加上 zz

  • 求以 xx 为根节点的子树内所有节点值之和

# 代码分析:

# dfs1

dfs1dfs1 要预处理出以下内容:

  • fa[x]:fa[x]: 节点 xx 的父亲。

  • dep[x]:dep[x]: 节点 xx 的深度。

  • siz[x]:siz[x]: 以节点 xx 为根的子树大小

  • son[x]:son[x]: 节点 xx 的重儿子编号

1
2
3
4
5
6
7
8
9
10
11
12
void dfs1(int x, int f){
fa[x] = f;
dep[x] = dep[f] + 1;
siz[x] = 1;
for(int i = head[x]; i; i = edge[i].nxt){
int y = edge[i].v;
if(y == f) continue;
dfs1(y, x);
siz[x] += siz[y];
if(siz[y] > siz[son[x]]) son[x] = y;
}
}

# dfs2

dfs2dfs2 也要预处理一些数据:

  • top[x]:top[x]: 节点 xx 所在重链的顶端

  • id[x]:id[x]: 把树改为链后节点 xx 新的编号

  • tw[cnt]:tw[cnt]: 把节点 xx 的权值存到新的编号中

1
2
3
4
5
6
7
8
9
10
11
12
void dfs2(int x, int topfa){
top[x] = topfa;
id[x] = ++cnt;
tw[cnt] = w[x];
if(!son[x]) return;
dfs2(son[x], topfa);
for(int i = head[x]; i; i = edge[i].nxt){
int y = edge[i].v;
if(y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}

# 线段树部分

(包括 pushuppushuppushdownpushdownbuildbuildupdateupdatequeryquery

就是普通的线段树维护区间加,区间查询

就不多讲了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
void pushup(int rt){
sum[rt] = (sum[ls] + sum[rs]) % mod;
}

void pushdown(int l, int r, int rt){
if(lazy[rt]){
int mid = (l + r) >> 1;
sum[ls] = (sum[ls] + lazy[rt] * (mid - l + 1) % mod) % mod;
sum[rs] = (sum[rs] + lazy[rt] * (r - mid) % mod) % mod;
lazy[ls] = (lazy[ls] + lazy[rt]) % mod;
lazy[rs] = (lazy[rs] + lazy[rt]) % mod;
lazy[rt] = 0;
}
}

void build(int l, int r, int rt){
if(l == r){
sum[rt] = tw[l] % mod;
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
pushup(rt);
}

void update(int L, int R, int k, int l, int r, int rt){
if(L <= l && r <= R){
sum[rt] = (sum[rt] + k * (r -l + 1) % mod) % mod;
lazy[rt] = (lazy[rt] + k) % mod;
return;
}
pushdown(l, r, rt);
int mid = (l + r) >> 1;
if(L <= mid) update(L, R, k, l, mid, ls);
if(R > mid) update(L, R, k, mid + 1, r, rs);
pushup(rt);
}

int query(int L, int R, int l, int r, int rt){
if(L <= l && r <= R)
return sum[rt];
pushdown(l, r, rt);
int mid = (l + r) >> 1;
int res = 0;
if(L <= mid) res = (res + query(L, R, l, mid, ls)) % mod;
if(R > mid) res = (res + query(L, R, mid + 1, r, rs)) % mod;
return res;
}

# update_Range

(链上修改)

这里是重点

因为一条重链上的点的编号一定是连续的,所以重链上的点相当于是一段区间,我们可以利用线段树来进行区间维护及查询,以减少时间复杂度

具体看代码

1
2
3
4
5
6
7
8
9
10
void update_Range(int x, int y, int k){
k %= mod;
while(top[x] != top[y]){ //x和y不在同一条重链上
if(dep[top[x]] < dep[top[y]]) swap(x, y); //令深度大的点往上跳
update(id[top[x]], id[x], k, 1, n, 1);
x = fa[top[x]]; //然后自己跳到当前重链的父节点上(进入下一条重链)
}
if(id[x] > id[y]) swap(x, y); //这时两个点已经在同一条重链上
update(id[x], id[y], k, 1, n, 1); //直接update两个点之间的点就好
}

# query_Range

(链上查询)

与链上修改一个道理,直接看代码吧

1
2
3
4
5
6
7
8
9
10
11
12
//查询同理
int query_Range(int x, int y){
int res = 0;
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
res = (res + query(id[top[x]], id[x], 1, n, 1)) % mod;
x = fa[top[x]];
}
if(id[x] > id[y]) swap(x, y);
res = (res + query(id[x], id[y], 1, n, 1)) % mod;
return res;
}

# update_Son

(子树修改)

根据深搜的性质,一棵子树中的点的编号也一定是连续的,所以……

1
2
3
void update_Son(int x, int k){
update(id[x], id[x] + siz[x] - 1, k, 1, n, 1); //这里区间修改的右端点就是id[x] + siz[x] - 1
}

# query_Son

(子树查询)

这里跟子树修改一样

1
2
3
int query_Son(int x){
return query(id[x], id[x] + siz[x] - 1, 1, n, 1);
}

# 完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include <iostream>
#define ls rt << 1
#define rs (rt << 1) | 1

using namespace std;

const int N = 1e5 + 10;
struct node{
int v, nxt;
}edge[N << 1];
int head[N], tot;
int n, m, root, mod, w[N];
int dep[N], fa[N], siz[N], son[N]; //dfs1需维护的
int top[N], id[N], tw[N], cnt; //dfs2需维护的
int sum[N << 2], lazy[N << 2]; //经典线段树

inline int read(){
int x = 0, f = 1;
char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x * f;
}

inline void add(int x, int y){
edge[++tot] = (node) {y, head[x]};
head[x] = tot;
}

//------------------------------------------------以下是dfs预处理
void dfs1(int x, int f){
fa[x] = f;
dep[x] = dep[f] + 1;
siz[x] = 1;
for(int i = head[x]; i; i = edge[i].nxt){
int y = edge[i].v;
if(y == f) continue;
dfs1(y, x);
siz[x] += siz[y];
if(siz[y] > siz[son[x]]) son[x] = y;
}
}

void dfs2(int x, int topfa){
top[x] = topfa;
id[x] = ++cnt;
tw[cnt] = w[x];
if(!son[x]) return;
dfs2(son[x], topfa);
for(int i = head[x]; i; i = edge[i].nxt){
int y = edge[i].v;
if(y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}

//------------------------------------------------以下是线段树
void pushup(int rt){
sum[rt] = (sum[ls] + sum[rs]) % mod;
}

void pushdown(int l, int r, int rt){
if(lazy[rt]){
int mid = (l + r) >> 1;
sum[ls] = (sum[ls] + lazy[rt] * (mid - l + 1) % mod) % mod;
sum[rs] = (sum[rs] + lazy[rt] * (r - mid) % mod) % mod;
lazy[ls] = (lazy[ls] + lazy[rt]) % mod;
lazy[rs] = (lazy[rs] + lazy[rt]) % mod;
lazy[rt] = 0;
}
}

void build(int l, int r, int rt){
if(l == r){
sum[rt] = tw[l] % mod;
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
pushup(rt);
}

void update(int L, int R, int k, int l, int r, int rt){
if(L <= l && r <= R){
sum[rt] = (sum[rt] + k * (r -l + 1) % mod) % mod;
lazy[rt] = (lazy[rt] + k) % mod;
return;
}
pushdown(l, r, rt);
int mid = (l + r) >> 1;
if(L <= mid) update(L, R, k, l, mid, ls);
if(R > mid) update(L, R, k, mid + 1, r, rs);
pushup(rt);
}

int query(int L, int R, int l, int r, int rt){
if(L <= l && r <= R)
return sum[rt];
pushdown(l, r, rt);
int mid = (l + r) >> 1;
int res = 0;
if(L <= mid) res = (res + query(L, R, l, mid, ls)) % mod;
if(R > mid) res = (res + query(L, R, mid + 1, r, rs)) % mod;
return res;
}

//------------------------------------------------以下是链上修改及查询
void update_Range(int x, int y, int k){
k %= mod;
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
update(id[top[x]], id[x], k, 1, n, 1);
x = fa[top[x]];
}
if(id[x] > id[y]) swap(x, y);
update(id[x], id[y], k, 1, n, 1);
}

int query_Range(int x, int y){
int res = 0;
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
res = (res + query(id[top[x]], id[x], 1, n, 1)) % mod;
x = fa[top[x]];
}
if(id[x] > id[y]) swap(x, y);
res = (res + query(id[x], id[y], 1, n, 1)) % mod;
return res;
}

//------------------------------------------------以下是子树上修改及查询
void update_Son(int x, int k){
update(id[x], id[x] + siz[x] - 1, k, 1, n, 1);
}

int query_Son(int x){
return query(id[x], id[x] + siz[x] - 1, 1, n, 1);
}

int main(){
n = read(), m = read(), root = read(), mod = read();
for(int i = 1; i <= n; i++)
w[i] = read();
for(int i = 1; i < n; i++){
int u, v;
u = read(), v = read();
add(u, v), add(v, u);
}
dfs1(root, 0);
dfs2(root, root);
build(1, n, 1);
while(m--){
int op, x, y, k;
op = read();
if(op == 1){
x = read(), y = read(), k = read();
update_Range(x, y, k);
}else if(op == 2){
x = read(), y = read();
printf("%d\n", query_Range(x, y));
}else if(op == 3){
x = read(), k = read();
update_Son(x, k);
}else{
x = read();
printf("%d\n", query_Son(x));
}
}
return 0;
}