树状数组 + 主席树,双线段树,树状数组 + 分块,最短路 + 拓扑排序

数据结构场,4 道题 3 道数据结构,还不给大样例 /tuu

# A. Count

考场上题面有误, < 号写成了 <= 。导致本来可以用主席树轻松维护的数据多出来一个乱七八糟的东西,于是只能写了个 O(nm)O(nm) 的暴力,结果条件不对遗憾爆零。

真的是服了。(另外有人能在题面有误的情况下爆切这道题我只能甘拜下风)

首先一定要搞清楚这题什么情况下才是好点对。搞清楚之后下面就是如何求好点对个数。

预处理 prei,nxtipre_i, nxt_ipreipre_i 表示 ii 左边大于等于 aia_i 的最靠右的位置,nxtinxt_i 同理。

对于 i=j1i = j - 1 的好点对单独处理,以下只考虑两个点之间距离大于 1 的情况。

对于询问 lrl \sim r,枚举 i[l,r]i \in [l, r],如果 prei<i1pre_i < i - 1preilpre_i \geq l,那么答案加 1;如果 nxti>i+1nxt_i > i + 1nxtirnxt_i \leq r 答案再加 1。

需要注意的是,ai=aja_i = a_j 的情况会重复计算。解决方法是在计算 preipre_i 是否有贡献时多加一个判断:apreiaia_{pre_i} \neq a_i。这样一来,nxtinxt_i 会计算等于的情况,而 preipre_i 不会。

然后就可以得到 40pts\text{40pts} 的好成绩。

明显是可以优化的。

对于 preii1pre_i \geq i - 1aprei=aia_{pre_i} = a_i 的,直接令 prei=0pre_i = 0,这样对答案就一定没有贡献了,nxtinxt_i 只有在 nxtii+1nxt_i \leq i + 1 时,令其等于 00

先只考虑 preipre_i 的贡献。

preipre_i 建一棵主席树,贡献就是 lrl \sim r 区间内,大于等于 ll 的数的个数。

nxtinxt_i 的贡献就是 lrl \sim r 内小于等于 rr 的个数。

再从头说一遍。

首先预处理,我使用的树状数组,实际上用单调栈就可以。

然后对于 preipre_inxtinxt_i 分别建一棵主席树,查询即可。

#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define pc putchar
#define mid ((l + r) >> 1)
using namespace std;
namespace IO{
    inline int read(){
        int x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }
    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const int N = 3e5 + 10;
const int inf = 1e9;
int n, m, typ, tot;
int lst;
int a[N], b[N];
int pre[N], nxt[N];
int c[N];
inline void add1(int x, int y){
    for(; x; x -= x & (-x)) c[x] = max(c[x], y);
}
inline int qry1(int x){
    int res = 0;
    for(; x <= tot; x += x & (-x)) res = max(res, c[x]);
    return res;
}
inline void add2(int x, int y){
    for(; x; x -= x & (-x)) c[x] = min(c[x], y);
}
inline int qry2(int x){
    int res = N;
    for(; x <= tot; x += x & (-x)) res = min(res, c[x]);
    return res == N ? 0 : res;
}
namespace T1{
    int sum[N << 5], ls[N << 5], rs[N << 5];
    int rt[N], cnt, now;
    inline int update(int pre, int l, int r, int x, int k){
        int p = ++cnt;
        sum[p] = sum[pre] + 1, ls[p] = ls[pre], rs[p] = rs[pre];
        if(l == r) return p;
        if(x <= mid) ls[p] = update(ls[pre], l, mid, x, k);
        else rs[p] = update(rs[pre], mid + 1, r, x, k);
        return p;
    }
    inline int query(int u, int v, int k, int l, int r){
        if(l == r) return sum[v] - sum[u];
        if(k <= mid) return sum[rs[v]] - sum[rs[u]] + query(ls[u], ls[v], k, l, mid);
        else return query(rs[u], rs[v], k, mid + 1, r);
    }
}
namespace T2{
    int sum[N << 5], ls[N << 5], rs[N << 5];
    int rt[N], cnt, now;
    inline int update(int pre, int l, int r, int x, int k){
        int p = ++cnt;
        sum[p] = sum[pre] + 1, ls[p] = ls[pre], rs[p] = rs[pre];
        if(l == r) return p;
        if(x <= mid) ls[p] = update(ls[pre], l, mid, x, k);
        else rs[p] = update(rs[pre], mid + 1, r, x, k);
        return p;
    }
    inline int query(int u, int v, int k, int l, int r){
        if(l == r) return sum[v] - sum[u];
        if(k > mid) return sum[ls[v]] - sum[ls[u]] + query(rs[u], rs[v], k, mid + 1, r);
        else return query(ls[u], ls[v], k, l, mid);
    }
}
signed main(){
#ifndef ONLINE_JUDGE
    freopen("test.in", "r", stdin);
    freopen("test.out", "w", stdout);
#endif
    n = read(), m = read(), typ = read();
    for(int i = 1; i <= n; ++i) a[i] = b[i] = read();
    sort(b + 1, b + 1 + n);
    tot = unique(b + 1, b + 1 + n) - b - 1;
    for(int i = 1; i <= n; ++i) a[i] = lower_bound(b + 1, b + 1 + tot, a[i]) - b;
    for(int i = 1; i <= n; ++i)
        pre[i] = qry1(a[i]), add1(a[i], i);
    memset(c, 0x3f, sizeof(c));
    for(int i = n; i >= 1; --i)
        nxt[i] = qry2(a[i]), add2(a[i], i);
    for(int i = 1; i <= n; ++i){
        if(a[pre[i]] == a[i] || pre[i] >= i - 1) pre[i] = 0;
        if(nxt[i] <= i + 1) nxt[i] = n + 1;
    }
    for(int i = 1; i <= n; ++i){
        T1::rt[i] = T1::update(T1::rt[i - 1], 0, n + 1, pre[i], 1);
        T2::rt[i] = T2::update(T2::rt[i - 1], 0, n + 1, nxt[i], 1);
    }
    while(m--){
        int l = read(), r = read();
        if(typ) l = (l + lst - 1) % n + 1, r = (r + lst - 1) % n + 1;
        if(l > r) swap(l, r);
        int res = T1::query(T1::rt[l - 1], T1::rt[r], l, 0, n + 1) + T2::query(T2::rt[l - 1], T2::rt[r], r, 0, n + 1);
        write(lst = (res + r - l)), puts("");
    }
    return 0;
}

# B. Abs

考场上写了个随机数据下很快的代码,实测得了 60pts\text{60pts} 感觉还行。

就是区间修改时对于负数一直递归下去单点修改,很暴力。

下面是正解。

维护两个线段树,一个维护正数,另一个维护负数,维护负数的线段树同时维护一个负数的最大值。

区间加时如果最大的负数加上这个数之后会变成正数,那么递归下去单点修改。

实现细节。

要维护一个 cntrtcnt_{rt} 表示区间内正数 / 负数个数,便于区间加时计算整个区间到底加了多少。

维护负数的线段树要可以把负数作为正数维护,然后维护的最大值变成维护最小值,区间加变成区间减。当然也可以正着写。

#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define pc putchar
#define mid ((l + r) >> 1)
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
namespace IO{
    inline int read(){
        int x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }
    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const int N = 2e5 + 10;
const int inf = 1e18;
int n, m;
int a[N];
vector <int> g[N];
int fa[N], siz[N], dep[N], son[N];
inline void dfs1(int x, int f){
    fa[x] = f, dep[x] = dep[f] + 1, siz[x] = 1;
    for(auto y : g[x]){
        if(y == f) continue;
        dfs1(y, x), siz[x] += siz[y];
        if(siz[y] > siz[son[x]]) son[x] = y;
    }
}
int dfn[N], top[N], id[N], tim;
inline void dfs2(int x, int topfa){
    top[x] = topfa, dfn[x] = ++tim, id[tim] = x;
    if(son[x]) dfs2(son[x], topfa);
    for(auto y : g[x])
        if(y != fa[x] && y != son[x]) dfs2(y, y);
}
int sum[N << 2][2], cnt[N << 2][2], tag[N << 2][2], mn[N << 2];
inline void pushup(int rt){
    sum[rt][0] = (cnt[ls][0] ? sum[ls][0] : 0) + (cnt[rs][0] ? sum[rs][0] : 0);
    sum[rt][1] = (cnt[ls][1] ? sum[ls][1] : 0) + (sum[rs][1] ? sum[rs][1] : 0);
    mn[rt] = min(mn[ls], mn[rs]);
    cnt[rt][0] = cnt[ls][0] + cnt[rs][0];
    cnt[rt][1] = cnt[ls][1] + cnt[rs][1];
}
inline void pushtag(int x, int y, int k){
    if(!cnt[x][y]) return;
    sum[x][y] += cnt[x][y] * k, tag[x][y] += k;
    if(y == 1) mn[x] += k;
}
inline void pushdown(int rt){
    if(tag[rt][0]){
        pushtag(ls, 0, tag[rt][0]), pushtag(rs, 0, tag[rt][0]);
        tag[rt][0] = 0;
    }
    if(tag[rt][1]){
        pushtag(ls, 1, tag[rt][1]), pushtag(rs, 1, tag[rt][1]);
        tag[rt][1] = 0;
    }
}
inline void build(int l, int r, int rt){
    if(l == r){
        if(a[id[l]] >= 0) sum[rt][0] = a[id[l]], cnt[rt][0] = 1, mn[rt] = inf;
        else sum[rt][1] = -a[id[l]], cnt[rt][1] = 1, mn[rt] = -a[id[l]];
        return;
    }
    build(l, mid, ls), build(mid + 1, r, rs);
    pushup(rt);
}
inline void upd(int L, int R, int k, int l, int r, int rt){
    if(L > r || R < l) return;
    if(L <= l && r <= R && mn[rt] >= k) return pushtag(rt, 0, k), pushtag(rt, 1, -k), void();
    pushdown(rt);
    if(l == r){
        sum[rt][0] = k - mn[rt], cnt[rt][0] = 1;
        sum[rt][1] = 0, mn[rt] = inf, cnt[rt][1] = 0;
        return;
    }
    upd(L, R, k, l, mid, ls), upd(L, R, k, mid + 1, r, rs);
    pushup(rt);
}
inline int qry(int L, int R, int l, int r, int rt){
    if(L > r || R < l) return 0;
    if(L <= l && r <= R) return sum[rt][0] + sum[rt][1];
    pushdown(rt);
    return qry(L, R, l, mid, ls) + qry(L, R, mid + 1, r, rs);
}
inline void update(int x, int y, int k){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        upd(dfn[top[x]], dfn[x], k, 1, n, 1);
        x = fa[top[x]];
    }
    if(dfn[x] > dfn[y]) swap(x, y);
    upd(dfn[x], dfn[y], k, 1, n, 1);
}
inline int query(int x, int y){
    int res = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        res += qry(dfn[top[x]], dfn[x], 1, n, 1);
        x = fa[top[x]];
    }
    if(dfn[x] > dfn[y]) swap(x, y);
    res += qry(dfn[x], dfn[y], 1, n, 1);
    return res;
}
signed main(){
#ifndef ONLINE_JUDGE
    freopen("test.in", "r", stdin);
    freopen("test.out", "w", stdout);
#endif
    n = read(), m = read();
    for(int i = 1; i <= n; ++i) a[i] = read();
    for(int i = 1, u, v; i < n; ++i)
        u = read(), v = read(), g[u].pb(v), g[v].pb(u);
    dfs1(1, 0), dfs2(1, 1), build(1, n, 1);
    while(m--){
        int op = read(), u = read(), v = read();
        if(op == 1) update(u, v, read());
        else write(query(u, v)), puts("");
    }
    return 0;
}

# C. 普通计算姬

暴力可得 90pts\text{90pts} 是我没想到的。

一个点的权值 valxval_x 对区间 lrl \sim r 的贡献是 xx 及它的祖先在 lrl \sim r 出现个数乘上 valxval_x

考虑序列分块。

ti,jt_{i, j} 表示 ii 及祖先在第 jj 块出现次数。

这个东西跑一遍 dfsdfs 预处理一下就行。

再维护一个 sumisum_i 表示第 ii 块的答案,根据 ti,jt_{i, j} 计算一下就行。

再来看如何维护散块;

考虑使用树状数组,往树状数组里插入每个点的权值 axa_x,然后查询 dfsdfs 序上 stxedxst_x \sim ed_x 的区间和就是 xx 的子树和。

块长为 n\sqrt n 时,复杂度为 O(nnlogn)O(n\sqrt n \log n),不过可以根号平衡。

实际写的时候多测测块长吧。

#include <bits/stdc++.h>
#define ull unsigned long long
#define pb push_back
#define pc putchar
#define mid ((l + r) >> 1)
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
namespace IO{
    inline int read(){
        int x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }
    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const int N = 1e5 + 10;
const int M = 500;
int n, m, rt, B, tot;
int a[N];
vector <int> g[N];
int be[N], L[N], R[N], t[N][M];
ull sum[M];
int fa[N], st[N], ed[N], tim;
inline void dfs(int x, int f){
    fa[x] = f, st[x] = ++tim;
    for(int i = 1; i <= tot; ++i) t[x][i] = t[f][i];
    t[x][be[x]]++;
    for(auto y : g[x]) if(y != f) dfs(y, x);
    ed[x] = tim;
}
ull c[N];
inline void add(int x, int y){
    for(; x <= n; x += x & (-x)) c[x] += y;
}
inline ull qry(int x){
    ull res = 0;
    for(; x; x -= x & (-x)) res += c[x];
    return res;
}
inline ull qry(int l, int r){
    return qry(r) - qry(l - 1);
}
inline void init(){
    B = sqrt(n); tot = n / B;
    if(n % B) tot++;
    for(int i = 1; i <= n; ++i) be[i] = (i - 1) / B + 1;
    for(int i = 1; i <= tot; ++i) L[i] = R[i - 1] + 1, R[i] = i * B;
    R[tot] = n, dfs(rt, 0);
    for(int i = 1; i <= n; ++i)
        for(int j = 1; j <= tot; ++j)
            sum[j] += 1ull * a[i] * t[i][j];
    for(int i = 1; i <= n; ++i) add(st[i], a[i]);
}
inline void update(int x, int k){
    for(int i = 1; i <= tot; ++i) sum[i] += 1ull * k * t[x][i];
    add(st[x], k);
}
inline ull query(int l, int r){
    ull res = 0;
    if(be[l] == be[r]){
        for(int i = l; i <= r; ++i) res += qry(st[i], ed[i]);
        return res;
    }
    for(int i = be[l] + 1; i <= be[r] - 1; ++i) res += sum[i];
    for(int i = l; i <= R[be[l]]; ++i) res += qry(st[i], ed[i]);
    for(int i = L[be[r]]; i <= r; ++i) res += qry(st[i], ed[i]);
    return res;
}
signed main(){
#ifndef ONLINE_JUDGE
    freopen("test.in", "r", stdin);
    freopen("test.out", "w", stdout);
#endif
    n = read(), m = read();
    for(int i = 1; i <= n; ++i) a[i] = read();
    for(int i = 1, u, v; i <= n; ++i){
        u = read(), v = read(), g[u].pb(v), g[v].pb(u);
        if(u == 0) rt = v;
    }
    init();
    while(m--){
        int op = read(), u = read(), v = read();
        if(op == 1) update(u, v - a[u]), a[u] = v;
        else write(query(u, v)), puts("");
    }
    return 0;
}

# D. Road

首先有一条性质,如果最短路上有:

disv=disu+w(u,v)dis_v = dis_u + w(u, v)

那么 (u,v)(u, v) 这条边就是在最短路上的。

先枚举一个源点 ss,从 ss 跑最短路,然后建最短路图,最短路图就是如果一条边是在最短路上的就加进去。

最短路图是一个 DAG\text{DAG},对于上面的一条最短路上的边 (u,v)(u, v),这条边被经过的次数就是最短路图上从 ss 考试能到达 uu 的路径数乘上 vv 能到达的点数。

对于从 ssuu 的路径数,直接跑一遍拓扑排序即可。

对于 vv 能到达的点数,反着跑一遍拓扑(可以在正着拓扑时开个栈记录每个点,反着遍历栈即可)。

具体看代码吧。

#include <bits/stdc++.h>
#define pb push_back
#define fi first
#define se second
using namespace std;
namespace IO{
    inline int read(){
        int x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }
    template <typename T> inline void write(T x){
        if(x < 0) putchar('-'), x = -x;
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const int N = 1.5e5;
const int mod = 1e9 + 7;
int n, m;
struct node{
    int u, v, w, nxt;
}e[N], edge[N];
int head[N], tot;
inline int add(int x) {return x >= mod ? x - mod : x;}
inline int sub(int x) {return x < 0 ? x + mod : x;}
inline int mul(int x, int y) {return 1ll * x * y % mod;}
inline void add(int x, int y, int z){
    edge[++tot] = (node){x, y, z, head[x]};
    head[x] = tot;
}
typedef pair<int, int> P;
int dis[N], in[N], flag[N];
inline void dijkstra(int s){
    priority_queue <P, vector<P>, greater<P> > q;
    memset(dis, 0x3f, sizeof(dis));
    q.push(P(0, s)), dis[s] = 0;
    while(!q.empty()){
        P p = q.top(); q.pop();
        int x = p.se;
        if(dis[x] < p.fi) continue;
        for(int i = head[x], y; i; i = edge[i].nxt)
            if(dis[y = edge[i].v] > dis[x] + edge[i].w)
                dis[y] = dis[x] + edge[i].w, q.push(P(dis[y], y));
    }
}
int c1[N], c2[N], ans[N];
int stk[N], top;
inline void topo(int s){
    memset(c1, 0, sizeof(c1));
    memset(c2, 0, sizeof(c2));
    queue <int> q; top = 0;
    q.push(s), c1[s] = 1;
    while(!q.empty()){
        int x = q.front(); q.pop();
        stk[++top] = x;
        for(int i = head[x]; i; i = edge[i].nxt){
            if(!flag[i]) continue;
            int y = edge[i].v;
            c1[y] = add(c1[y] + c1[x]);
            if(!(--in[y])) q.push(y);
        }
    }
    while(top){
        int x = stk[top--]; c2[x]++;
        for(int i = head[x]; i; i = edge[i].nxt)
            if(flag[i]) c2[x] = add(c2[x] + c2[edge[i].v]);
    }
}
inline void solve(int s){
    dijkstra(s);
    memset(flag, 0, sizeof(flag));
    memset(in, 0, sizeof(in));
    for(int i = 1; i <= m; ++i)
        if(dis[e[i].v] == dis[e[i].u] + e[i].w) flag[i] = 1, in[e[i].v]++;
    topo(s);
    for(int i = 1; i <= m; ++i)
        if(flag[i]) ans[i] = add(ans[i] + mul(c1[e[i].u], c2[e[i].v]));
}
signed main(){
#ifndef ONLINE_JUDGE
    freopen("test.in", "r", stdin);
    freopen("test.out", "w", stdout);
#endif
    n = read(), m = read();
    for(int i = 1; i <= m; ++i){
        int u = read(), v = read(), w = read();
        e[i] = (node){u, v, w, 0}, add(u, v, w);
    }
    for(int i = 1; i <= n; ++i) solve(i);
    for(int i = 1; i <= m; ++i) write(ans[i]), puts("");
    return 0;
}
更新于 阅读次数