# 树链剖分(重链剖分)

模板:洛谷 P3384 【模板】轻重链剖分 / 树链剖分

写在前面:强烈建议初学的同学如果不理解的话先手写一遍(像我一样),非常有助于理解

# 概念:

  • 重儿子: 一个节点所有儿子中最大的儿子

  • 轻儿子: 一个节点除重儿子之外的其他儿子

    特别地,叶子节点既没有重儿子也没有轻儿子

  • 重链: 重儿子连接形成的链叫重链

# 主要思想:

把一颗树拆成许多条链,把树上操作改为链上操作(区间操作),并利用线段树或树状数组等数据结构维护,以降低时间复杂度

每次进行操作时,先修改重链,再修改轻链

# 常见操作:

  • 将树从 xxyy 结点最短路径上所有节点的值都加上 zz

  • 求树从 xxyy 结点最短路径上所有节点的值之和

  • 将以 xx 为根节点的子树内所有节点值都加上 zz

  • 求以 xx 为根节点的子树内所有节点值之和

# 代码分析:

# dfs1

dfs1dfs1 要预处理出以下内容:

  • fa[x]:fa[x]: 节点 xx 的父亲。

  • dep[x]:dep[x]: 节点 xx 的深度。

  • siz[x]:siz[x]: 以节点 xx 为根的子树大小

  • son[x]:son[x]: 节点 xx 的重儿子编号

void dfs1(int x, int f){
    fa[x] = f;
    dep[x] = dep[f] + 1;
    siz[x] = 1;
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == f) continue;
        dfs1(y, x);
        siz[x] += siz[y];
        if(siz[y] > siz[son[x]]) son[x] = y;
    }
}

# dfs2

dfs2dfs2 也要预处理一些数据:

  • top[x]:top[x]: 节点 xx 所在重链的顶端

  • id[x]:id[x]: 把树改为链后节点 xx 新的编号

  • tw[cnt]:tw[cnt]: 把节点 xx 的权值存到新的编号中

void dfs2(int x, int topfa){
    top[x] = topfa;
    id[x] = ++cnt;
    tw[cnt] = w[x];
    if(!son[x]) return;
    dfs2(son[x], topfa);
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == fa[x] || y == son[x]) continue;
        dfs2(y, y);
    }
}

# 线段树部分

(包括 pushuppushuppushdownpushdownbuildbuildupdateupdatequeryquery

就是普通的线段树维护区间加,区间查询

就不多讲了

void pushup(int rt){
    sum[rt] = (sum[ls] + sum[rs]) % mod;
}
void pushdown(int l, int r, int rt){
    if(lazy[rt]){
        int mid = (l + r) >> 1;
        sum[ls] = (sum[ls] + lazy[rt] * (mid - l + 1) % mod) % mod;
        sum[rs] = (sum[rs] + lazy[rt] * (r - mid) % mod) % mod;
        lazy[ls] = (lazy[ls] + lazy[rt]) % mod;
        lazy[rs] = (lazy[rs] + lazy[rt]) % mod;
        lazy[rt] = 0;
    }
}
void build(int l, int r, int rt){
    if(l == r){
        sum[rt] = tw[l] % mod;
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, ls);
    build(mid + 1, r, rs);
    pushup(rt);
}
void update(int L, int R, int k, int l, int r, int rt){
    if(L <= l && r <= R){
        sum[rt] = (sum[rt] + k * (r -l + 1) % mod) % mod;
        lazy[rt] = (lazy[rt] + k) % mod;
        return;
    }
    pushdown(l, r, rt);
    int mid = (l + r) >> 1;
    if(L <= mid) update(L, R, k, l, mid, ls);
    if(R > mid) update(L, R, k, mid + 1, r, rs);
    pushup(rt);
}
int query(int L, int R, int l, int r, int rt){
    if(L <= l && r <= R)
        return sum[rt];
    pushdown(l, r, rt);
    int mid = (l + r) >> 1;
    int res = 0;
    if(L <= mid) res = (res + query(L, R, l, mid, ls)) % mod;
    if(R > mid) res = (res + query(L, R, mid + 1, r, rs)) % mod;
    return res;
}

# update_Range

(链上修改)

这里是重点

因为一条重链上的点的编号一定是连续的,所以重链上的点相当于是一段区间,我们可以利用线段树来进行区间维护及查询,以减少时间复杂度

具体看代码

void update_Range(int x, int y, int k){
   k %= mod;
   while(top[x] != top[y]){    //x 和 y 不在同一条重链上
       if(dep[top[x]] < dep[top[y]]) swap(x, y);  // 令深度大的点往上跳
       update(id[top[x]], id[x], k, 1, n, 1);
       x = fa[top[x]];    // 然后自己跳到当前重链的父节点上(进入下一条重链)
   }
   if(id[x] > id[y]) swap(x, y);    // 这时两个点已经在同一条重链上 
   update(id[x], id[y], k, 1, n, 1);  // 直接 update 两个点之间的点就好
}

# query_Range

(链上查询)

与链上修改一个道理,直接看代码吧

// 查询同理
int query_Range(int x, int y){
    int res = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        res = (res + query(id[top[x]], id[x], 1, n, 1)) % mod;
        x = fa[top[x]];
    }
    if(id[x] > id[y]) swap(x, y);
    res = (res + query(id[x], id[y], 1, n, 1)) % mod;
    return res;
}

# update_Son

(子树修改)

根据深搜的性质,一棵子树中的点的编号也一定是连续的,所以……

void update_Son(int x, int k){
    update(id[x], id[x] + siz[x] - 1, k, 1, n, 1);    // 这里区间修改的右端点就是 id [x] + siz [x] - 1
}

# query_Son

(子树查询)

这里跟子树修改一样

int query_Son(int x){
    return query(id[x], id[x] + siz[x] - 1, 1, n, 1);
}

# 完整代码

#include <iostream>
#define ls rt << 1
#define rs (rt << 1) | 1
using namespace std;
const int N = 1e5 + 10;
struct node{
    int v, nxt;
}edge[N << 1];
int head[N], tot;
int n, m, root, mod, w[N];
int dep[N], fa[N], siz[N], son[N];  //dfs1 需维护的
int top[N], id[N], tw[N], cnt;      //dfs2 需维护的
int sum[N << 2], lazy[N << 2];      // 经典线段树
inline int read(){
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x * f;
}
inline void add(int x, int y){
    edge[++tot] = (node) {y, head[x]};
    head[x] = tot;
}
//------------------------------------------------ 以下是 dfs 预处理
void dfs1(int x, int f){
    fa[x] = f;
    dep[x] = dep[f] + 1;
    siz[x] = 1;
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == f) continue;
        dfs1(y, x);
        siz[x] += siz[y];
        if(siz[y] > siz[son[x]]) son[x] = y;
    }
}
void dfs2(int x, int topfa){
    top[x] = topfa;
    id[x] = ++cnt;
    tw[cnt] = w[x];
    if(!son[x]) return;
    dfs2(son[x], topfa);
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == fa[x] || y == son[x]) continue;
        dfs2(y, y);
    }
}
//------------------------------------------------ 以下是线段树
void pushup(int rt){
    sum[rt] = (sum[ls] + sum[rs]) % mod;
}
void pushdown(int l, int r, int rt){
    if(lazy[rt]){
        int mid = (l + r) >> 1;
        sum[ls] = (sum[ls] + lazy[rt] * (mid - l + 1) % mod) % mod;
        sum[rs] = (sum[rs] + lazy[rt] * (r - mid) % mod) % mod;
        lazy[ls] = (lazy[ls] + lazy[rt]) % mod;
        lazy[rs] = (lazy[rs] + lazy[rt]) % mod;
        lazy[rt] = 0;
    }
}
void build(int l, int r, int rt){
    if(l == r){
        sum[rt] = tw[l] % mod;
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, ls);
    build(mid + 1, r, rs);
    pushup(rt);
}
void update(int L, int R, int k, int l, int r, int rt){
    if(L <= l && r <= R){
        sum[rt] = (sum[rt] + k * (r -l + 1) % mod) % mod;
        lazy[rt] = (lazy[rt] + k) % mod;
        return;
    }
    pushdown(l, r, rt);
    int mid = (l + r) >> 1;
    if(L <= mid) update(L, R, k, l, mid, ls);
    if(R > mid) update(L, R, k, mid + 1, r, rs);
    pushup(rt);
}
int query(int L, int R, int l, int r, int rt){
    if(L <= l && r <= R)
        return sum[rt];
    pushdown(l, r, rt);
    int mid = (l + r) >> 1;
    int res = 0;
    if(L <= mid) res = (res + query(L, R, l, mid, ls)) % mod;
    if(R > mid) res = (res + query(L, R, mid + 1, r, rs)) % mod;
    return res;
}
//------------------------------------------------ 以下是链上修改及查询
void update_Range(int x, int y, int k){
    k %= mod;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        update(id[top[x]], id[x], k, 1, n, 1);
        x = fa[top[x]];
    }
    if(id[x] > id[y]) swap(x, y);
    update(id[x], id[y], k, 1, n, 1);
}
int query_Range(int x, int y){
    int res = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        res = (res + query(id[top[x]], id[x], 1, n, 1)) % mod;
        x = fa[top[x]];
    }
    if(id[x] > id[y]) swap(x, y);
    res = (res + query(id[x], id[y], 1, n, 1)) % mod;
    return res;
}
//------------------------------------------------ 以下是子树上修改及查询
void update_Son(int x, int k){
    update(id[x], id[x] + siz[x] - 1, k, 1, n, 1);
}
int query_Son(int x){
    return query(id[x], id[x] + siz[x] - 1, 1, n, 1);
}
int main(){
    n = read(), m = read(), root = read(), mod = read();
    for(int i = 1; i <= n; i++)
        w[i] = read();
    for(int i = 1; i < n; i++){
        int u, v;
        u = read(), v = read();
        add(u, v), add(v, u);
    }
    dfs1(root, 0);
    dfs2(root, root);
    build(1, n, 1);
    while(m--){
        int op, x, y, k;
        op = read();
        if(op == 1){
            x = read(), y = read(), k = read();
            update_Range(x, y, k);
        }else if(op == 2){
            x = read(), y = read();
            printf("%d\n", query_Range(x, y));
        }else if(op == 3){
            x = read(), k = read();
            update_Son(x, k);
        }else{
            x = read();
            printf("%d\n", query_Son(x));
        }
    }
    return 0;
}