set(并查集 + 时光倒流),树形 dp + 数据分治,dp

# A. 向日葵代沟

直接用 set\text{set} 维护每个断点。

每次新加一个断点时直接二分前后断点位置计算贡献即可。

#include <bits/stdc++.h>
#define int long long
using namespace std;
namespace IO{
    inline int read(){
        int x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }
    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const int N = 1e5 + 10;
int n, m, ans;
int a[N];
set <int> s;
bool vis[N];
signed main(){
    n = read(), m = read();
    s.insert(0), s.insert(n);
    vis[n] = 1;
    while(m--){
        int x = read();
        if(vis[x]) {puts("0"); continue;}
        vis[x] = 1;
        auto l = s.lower_bound(x), r = s.lower_bound(x);
        l--;
        ans = (x - (*l)) * ((*r) - x);
        write(ans), puts("");
        s.insert(x);
    }
    return 0;
}

# B. 树上的僵尸

这道题就复杂多了,考虑 dp\text{dp}

fxf_x 表示两个端点都在 xx 子树内且经过 xx 的合法路径数,gxg_x 表示一端在子树内另一端在子树外的合法路径数。

通过容斥思想,我们用总数减去不合法的来计算答案,不合法的即为所有路径都经过同一个点的情况。

直接枚举点肯定会超时,我们枚举所有路径经过深度最浅的点,那么不合法方案数就变为了总数减去不是最浅的方案数。

总数为 (fi+gi)m(f_i + g_i)^m,不是最浅的方案数为 gimg_i^m,答案为:

ans=(i=1nfi)mi=1n((fi+gi)mgim)ans = (\sum_{i = 1}^nf_i)^m - \sum_{i = 1}^n((f_i + g_i) ^ m - g_i ^m)

再来看 fif_igig_i 如何计算:

  • n1000n \leq 1000:暴力计算,转移的时候就是求编号小于当前点且距离当前点 kk 以内的点的个数,用树状数组维护前缀和即可,复杂度 O(n2logn)O(n^2\log n)

  • k=n1k = n - 1fxf_xxx 子树内任意两点配对个数减去所有儿子自己配对个数,gxg_x(nsizx)×sizx(n - siz_x) \times siz_x

  • 链:因为是一条链,且要经过 xx,相当于有一个端点必须在 xxfxf_xmin{k+1,ni+1}\min\{k + 1, n - i + 1 \}gxg_xxx 向上 kk 步内 fxf_x 的和减去 xxxxkk 级祖先内两两配对的个数。

k=0k = 0 的情况答案就是 nmnn^m - n

然后这题就解决了。

#include <bits/stdc++.h>
#define pb push_back
// #define int long long
using namespace std;
namespace IO{
    inline int read(){
        int x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }
    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const int N = 1e6 + 10;
const int mod = 998244353;
int n, m, k, ans;
vector <int> G[N], vec;
inline int add(int x) {return x >= mod ? x - mod : x;}
inline int sub(int x) {return x < 0 ? x + mod : x;}
inline int mul(int x, int y) {return 1ll * x * y % mod;}
inline int qpow(int a, int b){
    int res = 1;
    while(b){
        if(b & 1) res = mul(res, a);
        a = mul(a, a), b >>= 1;
    }
    return res;
}
int dep[N], sum[N];
int f[N], g[N];
int c[N];
inline void add(int x, int y){
    for(; x <= k; x += x & (-x)) c[x] = add(c[x] + y);
}
inline int qry(int x){
    int res = 0;
    for(; x; x -= x & (-x)) res = add(res + c[x]);
    return res;
}
inline void dfs1(int x, int fa){
    if(!x) return;
    if((dep[x] = dep[fa] + 1) > k) return;
    vec.pb(x);
    for(auto y : G[x]) if(y != fa) dfs1(y, x);
}
inline void dfs2(int x, int fa){
    dep[x] = 0, dfs1(fa, x);
    for(int i = 1; i <= k; ++i) sum[i] = c[i] = 0;
    for(auto v : vec) sum[dep[v]]++;
    vec.clear();
    for(int i = 1; i <= k; ++i) sum[i] += sum[i - 1];
    f[x] = 1, g[x] = sum[k];
    for(auto y : G[x]){
        if(y == fa) continue;
        dfs1(y, x);
        for(auto v : vec){
            f[x] = add(f[x] + 1 + qry(k - dep[v]));
            g[x] = add(g[x] + sum[k - dep[v]]);
        }
        for(auto v : vec) add(dep[v], 1);
        vec.clear();
    }
    for(auto y : G[x]) if(y != fa) dfs2(y, x);
}
inline void solve_pow(){
    dfs2(1, 0);
}
int siz[N];
inline int Sum(int n) {return 1ll * n * (n + 1) / 2 % mod;}
inline void dfs(int x, int fa){
    siz[x] = 1;
    for(auto y : G[x]){
        if(y == fa) continue;
        dfs(y, x), siz[x] += siz[y];
        f[x] = sub(f[x] - Sum(siz[y]));
    }
    f[x] = add(f[x] + Sum(siz[x])), g[x] = mul(siz[x], n - siz[x]);
}
inline void solve_all(){
    dfs(1, 0);
}
inline void solve_chain(){
    for(int i = 1; i <= n; ++i){
        f[i] = min(k + 1, n - i + 1), sum[i] = add(sum[i - 1] + f[i]);
        int j = max(i - k, 1);
        g[i] = sub(sub(sum[i - 1] - sum[j - 1]) - Sum(i - j));
    }
}
signed main(){
#ifndef ONLINE_JUDGE
    freopen("test.in", "r", stdin);
    freopen("test.out", "w", stdout);
#endif
    n = read(), m = read(), k = read();
    for(int i = 1; i < n; ++i){
        int u = read(), v = read();
        G[u].pb(v), G[v].pb(u);
    }
    if(!k) return write(sub(qpow(n, m) - n)), puts(""), 0;
    if(n <= 1000) solve_pow();
    else if(k == n - 1) solve_all();
    else solve_chain();
    int res1 = 0, res2 = 0;
    for(int i = 1; i <= n; ++i){
        res1 = add(res1 + f[i]);
        res2 = add(res2 + sub(qpow(f[i] + g[i], m) - qpow(g[i], m)));
    }
    write(sub(qpow(res1, m) - res2)), puts("");
    return 0;
}

# C. 豌豆射手

考场上一直再调假 dp\text{dp},心态没了 /kk

先想一个比较暴力的做法,枚举豌豆射手的排列,然后尽可能压缩空间的放置。

假设占用的长度为 ss,考虑把后面 dsd - s 格的空位置插到前面 n+1n + 1 个位置中,可以为空,方案数为 (ds+nn)\dbinom {d - s + n}{n}

但是枚举排列显然是过不了的,考虑如何通过 dp\text{dp} 直接计算出占用 ss 的草坪时的方案数。

fi,j,kf_{i, j, k} 表示计算了前 ii 个豌豆射手,形成了 jj 段(每段内无法插入新的豌豆射手,段与段之间还有空位置),使用长度为 kk 草坪时的方案数。

先把豌豆射手按照 rr 从小到大排序,方便转移。

转移分为 3 种情况:

  • 新加的豌豆射手插入到某一段中:fi,j,kf_{i, j, k} 贡献次数为 jj(插到任何一段都行),那么有转移方程:

    fi+1,j,k+ri+1+=fi,j,k×jf_{i + 1, j, k + r_{i + 1}} += f_{i, j, k} \times j

  • 新加的豌豆射手把两段合并到一起:fi,j,kf_{i, j, k} 贡献次数为 j×(j1)j \times (j - 1)(任意两段都有可能合并,且合并顺序有关,即 (a,b)(a, b)(b,a)(b, a) 为不同的合并),有转移方程:

    fi+1,j1,k+2×ri+11+=fi,j,k×j×(j1)f_{i + 1, j - 1, k + 2 \times r_{i + 1} - 1} += f_{i, j, k} \times j \times (j - 1)

  • 新加的自成一段:

    fi+1,j+1,k+1+=fi,j,kf_{i + 1, j + 1, k + 1} += f_{i, j, k}

代码十分简单:

#include <bits/stdc++.h>
#define pb push_back
using namespace std;
namespace IO{
    inline int read(){
        int x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }
    template <typename T> inline void write(T x){
        if(x < 0) putchar('-'), x = -x;
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const int N = 45;
const int S = 1600;
const int M = 1e5 + 10;
const int mod = 1e9 + 7;
int n, d, ans;
int r[N], f[N][N][S + 10];
inline int add(int x) {return x >= mod ? x - mod : x;}
inline void add(int &x, int y) {x = add(x + y);}
inline int sub(int x) {return x < 0 ? x + mod : x;}
inline int mul(int x, int y) {return 1ll * x * y % mod;}
int fac[M], ifac[M];
inline int qpow(int a, int b){
    int res = 1;
    for(; b; a = mul(a, a), b >>= 1)
        if(b & 1) res = mul(res, a);
    return res;
}
inline void init(int n){
    fac[0] = 1;
    for(int i = 1; i <= n; ++i) fac[i] = mul(fac[i - 1], i);
    ifac[n] = qpow(fac[n], mod - 2);
    for(int i = n - 1; i >= 0; --i) ifac[i] = mul(ifac[i + 1], i + 1);
}
inline int C(int n, int m){
    return mul(fac[n], mul(ifac[m], ifac[n - m]));
}
inline void solve(){
    f[0][0][0] = 1;
    for(int i = 0; i <= n; ++i)
        for(int j = 0; j <= i; ++j)
            for(int k = 0; k <= S; ++k){
                int t = r[i + 1];
                if(j >= 1) add(f[i + 1][j][k + t], mul(f[i][j][k], 2 * j));
                if(j >= 2) add(f[i + 1][j - 1][k + 2 * t - 1], mul(f[i][j][k], j * (j - 1)));
                add(f[i + 1][j + 1][k + 1], f[i][j][k]);
            }
}
signed main(){
#ifndef ONLINE_JUDGE
    freopen("test.in", "r", stdin);
    freopen("test.out", "w", stdout);
#endif
    n = read(), d = read(), init(M - 10);
    for(int i = 1; i <= n; ++i) r[i] = read();
    sort(r + 1, r + 1 + n);
    solve();
    for(int i = 0; i <= min(d, S); ++i)
        ans = add(ans + mul(C(d - i + n, n), f[n][1][i]));
    write(ans), puts("");
    return 0;
}
更新于 阅读次数