# Description

P4178 Tree

# Solution

前两天 vp 时碰到一道需要在每个点维护一棵平衡树,同时需要树上启发式合并的题,然鹅发现自己不会写。

无意中发现这道题似乎也可以使用同样的做法。

大概思路是,对于每个点,先继承它的重儿子的平衡树,然后依次枚举所有的轻儿子,先统计答案,再合并平衡树。

统计答案与合并平衡树时,直接遍历轻儿子平衡树内所有的点即可。

具体来说,平衡树中存的值是子树内每个点到当前平衡树的根的点的距离。

假设当前处理到 xx 节点,我们先计算一端为 xx,另一端在重儿子子树内的合法点对数,即在重儿子的平衡树中查找小于等于 kdis(x,sonx)k - dis(x, son_x) 的点的个数。

然后我们把 xx 点插入到平衡树中。注意到这里平衡树的根从 sonxson_x 变成了 xx,这意味着需要一个平衡树整体加 dis(x,sonx)dis(x, son_x) 的操作,但是整体加还是太吃操作了,我们考虑把后面插入的节点都减去这个值也可以达到同样的效果。

接着枚举所有轻儿子,统计答案时遍历轻儿子的每一个点,取当前重儿子的平衡树中查询小于等于 kvalvdis(x,y)dis(x,sonx)k - val_v - dis(x, y) - dis(x, son_x) 的点个数即可(vv 是遍历到的点,yyvv 所在的子树的根,是 xx 的一个轻儿子)。

询问就处理完了。

合并也是同样的道理,只需要遍历轻儿子,将 valv+dis(x,y)dis(x,sonx)val_v + dis(x, y) - dis(x, son_x) 插入到重儿子的平衡树中即可。

实际上上面出现的所有距离的公式并不完全准确,只是举个例子,因为我们一直在搜索回溯,每次处理一个重儿子,距离都要整体加上一个数值,具体看代码理解吧。

#include <bits/stdc++.h>
#define fi first
#define se second
using namespace std;
const int N = 1e6 + 10;
typedef pair<int, int> pii;
int n, k;
vector <pii> G[N];
struct Tree{
    int siz[N], val[N], wei[N], ls[N], rs[N];
    int tot;
    
    inline void pushup(int x){
        siz[x] = siz[ls[x]] + siz[rs[x]] + 1;
    }
    inline void split(int x, int k, int &a, int &b){
        if(!x) return a = b = 0, void();
        if(k >= val[x]) a = x, split(rs[x], k, rs[x], b);
        else b = x, split(ls[x], k, a, ls[x]);
        pushup(x);
    }
    inline int merge(int x, int y){
        if(!x || !y) return x | y;
        if(wei[x] <= wei[y]){
            rs[x] = merge(rs[x], y);
            return pushup(x), x;
        }else{
            ls[y] = merge(x, ls[y]);
            return pushup(y), y;
        }
    }
    inline int newnode(int k){
        siz[++tot] = 1, val[tot] = k, wei[tot] = rand();
        return tot;
    }
    inline void ins(int &x, int k){
        int a, b; split(x, k, a, b);
        x = merge(a, merge(newnode(k), b));
    }
    inline int qry(int x, int k){
		int a, b; split(x, k, a, b);
		int res = siz[a];
		return x = merge(a, b), res;
    }
    inline int query(int &p, int q, int du, int dv){
        if(!q) return 0;
        int tmp = k - du - dv - val[q];
        int cnt = qry(p, tmp);
        return cnt + query(p, ls[q], du, dv) + query(p, rs[q], du, dv);
    }
    inline void Union(int &p, int q, int du, int dv){
        if(!q) return;
        ins(p, val[q] + dv - du);
        Union(p, ls[q], du, dv), Union(p, rs[q], du, dv);
    }
} t;
int siz[N], root[N], son[N];
int dis[N], sw[N];
int ans;
inline void dfs(int x, int fa){
    siz[x] = 1;
    for(auto y : G[x]){
        if(y.fi == fa) continue;
        dfs(y.fi, x), siz[x] += siz[y.fi];
        if(siz[y.fi] > siz[son[x]]) son[x] = y.fi, sw[x] = y.se;
    }
}
inline void dfs2(int x, int fa, int d){
    if(!son[x]){
        root[x] = t.newnode(0), dis[x] = d;
        return;
    }
    dfs2(son[x], x, sw[x]);
    root[x] = root[son[x]], dis[x] = dis[son[x]];
    ans += t.qry(root[x], k - dis[x]);
    // x ~ son[x]
    //  区间加
    t.ins(root[x], -dis[x]); // 插入 x
    for(auto y : G[x]){
        if(y.fi == fa || y.fi == son[x]) continue;
        dfs2(y.fi, x, y.se);
        ans += t.query(root[x], root[y.fi], dis[x], dis[y.fi]);
        t.Union(root[x], root[y.fi], dis[x], dis[y.fi]);
    }
    dis[x] += d;
}
signed main(){
#ifndef ONLINE_JUDGE
	freopen("test.in", "r", stdin);
	freopen("test.out", "w", stdout);
#endif
    ios :: sync_with_stdio(false);
    cin >> n;
    for(int i = 1; i < n; ++i){
        int u, v, w; cin >> u >> v >> w;
        G[u].push_back(pii(v, w)), G[v].push_back(pii(u, w));
    }
    cin >> k;
    dfs(1, 0), dfs2(1, 0, 0);
    cout << ans << endl;
    return 0;
}
更新于 阅读次数