# 主要思想

WQS 二分也叫带权二分,主要解决一类凸包上的问题。

假设当前你有一个很难计算的凸函数 f(x)f(x),此时你要计算一个指定的 f(n)f(n),如何快速求出呢?

我们可以去二分一个斜率 kk,然后计算斜率为 kk 的一条直线落到这个凸壳上时的交点。

设这个交点为 (x,f(x))(x, f(x)),再设函数 g(k)g(k) 表示表示斜率为 kk 的直线落到凸壳上时的截距。

那么:

f(x)=g(k)+kxg(k)=f(x)kxf(x) = g(k) + kx\\ g(k) = f(x) - kx

由于 g(k)g(k) 是一条直线,感性理解一下,我们在计算这条直线上的值时就不需要考虑原题中的限制条件了。

于是我们就可以通过二分斜率来快速的找出答案了。

可以通过下面的例题来理解一下。

# 例题

# Description

P4383 [八省联考 2018] 林克卡特树

# Solution

首先转化一下题意,删边再加边显然是我们不想看到的,所以我们可以理解为从这棵树中选择 k+1k + 1 条不相交的链,使它们的权值和最大。

我们先推出 60pts 的暴力 dp。

dp[0/1/2][x][i]dp[0/1/2][x][i] 表示以 xx 为根的子树中选择 ii 条完整链,xx 点的度数分别为 0/1/20/1/2 时的最大权值和。

然后分 3 种情况讨论:

  • xx 点度数为 0,那么此时 xx 不在链上,于是合并一下子树的权值即可。

    dp[0][x][i]=max(dp[0][x][j]+dp[0][y][ij])dp[0][x][i] = \max(dp[0][x][j] + dp[0][y][i - j])

  • xx 点度数为 1,xx 可能本身是一条链上的端点,或者当前从儿子连上来一条边,如果连上来一条边的话要记得加上边权。

    dp[1][x][i]=max(dp[1][x][j]+dp[0][y][ij],dp[0][x][j]+dp[1][y][ij]+wx,y)dp[1][x][i] = \max(dp[1][x][j] + dp[0][y][i - j], dp[0][x][j] + dp[1][y][i - j] + w_{x, y} )

  • xx 点度数为 2,这时也会有两种情况,xx 本身就是在一条链的中间,或者 xx 本身是一条链的端点,yy 也是一条链的端点,然后把 xxyy 连到一起,注意也要加上边权。

    dp[2][x][i]=max(dp[2][x][j]+dp[0][y][ij],dp[1][x][j]+dp[1][y][ij]+wx,y)dp[2][x][i] = \max(dp[2][x][j] + dp[0][y][i - j], dp[1][x][j] + dp[1][y][i - j] + w_{x, y} )

最后合并:

dp[0][x][i]=max(dp[1][x][i1],dp[2][x][i])dp[0][x][i] = \max(dp[1][x][i - 1], dp[2][x][i])

答案就是:dp[0][1][k]dp[0][1][k]

这样的复杂度是 O(nk2)O(nk^2) 的,显然无法通过此题,考虑优化。

不难发现dpdp 值是一个上凸函数,这时 WQS 二分就登场了。

我们二分一下斜率 midmid,每次暴力计算一下斜率为 midmid 的直线的值,这时没有了 kk 条边的限制,我们可以 O(n)O(n) 计算出 dpdp 值。

然后计算切点的位置,如果在 kk 左边,斜率变小,反之斜率变大。

于是这题就可以在 O(nlogk)O(n\log k) 的复杂度内完成啦。

# Code

mark
#include <bits/stdc++.h>
#define ll long long
using namespace std;
namespace IO{
    inline ll read(){
        ll x = 0, f = 1;
        char ch = getchar();
        while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x * f;
    }
    template <typename T> inline void write(T x){
        if(x < 0) putchar('-'), x = -x;
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const ll N = 3e5 + 10;
ll n, k, tot, mid;
struct node{
    ll v, w;
};
vector <node> g[N];
struct data{
    ll val, cnt;
    inline bool operator < (const data &b) const {return val != b.val ? val < b.val : cnt < b.cnt;}
    inline data operator + (const data &b) const {return (data){val + b.val, cnt + b.cnt};}
    inline data operator + (ll b) const {return (data){val + b, cnt};}
}dp[3][N];
inline data upd(data b) {return (data){b.val - mid, b.cnt + 1};}//y 合并到 x 上,对于 x 来说多了一条链
inline void dfs(ll u, ll fa){
    dp[2][u] = max(dp[2][u], (data){-mid, 1});
    for(auto y : g[u]){
        if(y.v == fa) continue;
        dfs(y.v, u);
        dp[2][u] = max(dp[2][u] + dp[0][y.v], upd(dp[1][u] + dp[1][y.v] + y.w));
        dp[1][u] = max(dp[1][u] + dp[0][y.v], dp[0][u] + dp[1][y.v] + y.w);
        dp[0][u] = dp[0][u] + dp[0][y.v];
    }
    dp[0][u] = max(dp[0][u], max(upd(dp[1][u]), dp[2][u]));
}
inline ll solve(){
    ll l = -tot, r = tot, res = 0;
    while(l <= r){
        mid = (l + r) >> 1;
        memset(dp, 0, sizeof(dp));
        dfs(1, 0);
        if(dp[0][1].cnt >= k) l = mid + 1, res = mid;
        else r = mid - 1;
    }
    memset(dp, 0, sizeof(dp));
    mid = res, dfs(1, 0);
    return res * k + dp[0][1].val;
}
signed main(){
    n = read(), k = read() + 1;
    for(ll i = 1; i < n; ++i){
        ll u = read(), v = read(), w = read();
        tot += abs(w);
        g[u].push_back((node){v, w}), g[v].push_back((node){u, w});
    }
    write(solve()), puts("");
    return 0;
}

# End