# 主要思想

WQS二分也叫带权二分,主要解决一类凸包上的问题。

假设当前你有一个很难计算的凸函数 f(x)f(x),此时你要计算一个指定的 f(n)f(n),如何快速求出呢?

我们可以去二分一个斜率 kk,然后计算斜率为 kk 的一条直线落到这个凸壳上时的交点。

设这个交点为 (x,f(x))(x, f(x)),再设函数 g(k)g(k) 表示表示斜率为 kk 的直线落到凸壳上时的截距。

那么:

f(x)=g(k)+kxg(k)=f(x)kxf(x) = g(k) + kx\\ g(k) = f(x) - kx

由于 g(k)g(k) 是一条直线,感性理解一下,我们在计算这条直线上的值时就不需要考虑原题中的限制条件了。

于是我们就可以通过二分斜率来快速的找出答案了。

可以通过下面的例题来理解一下。

# 例题

# Description

P4383 [八省联考 2018] 林克卡特树

# Solution

首先转化一下题意,删边再加边显然是我们不想看到的,所以我们可以理解为从这棵树中选择 k+1k + 1 条不相交的链,使它们的权值和最大。

我们先推出 60pts 的暴力 dp。

dp[0/1/2][x][i]dp[0/1/2][x][i] 表示以 xx 为根的子树中选择 ii 条完整链,xx 点的度数分别为 0/1/20/1/2 时的最大权值和。

然后分 3 种情况讨论:

  • xx 点度数为 0,那么此时 xx 不在链上,于是合并一下子树的权值即可。

    dp[0][x][i]=max(dp[0][x][j]+dp[0][y][ij])dp[0][x][i] = \max(dp[0][x][j] + dp[0][y][i - j])

  • xx 点度数为 1,xx 可能本身是一条链上的端点,或者当前从儿子连上来一条边,如果连上来一条边的话要记得加上边权。

    dp[1][x][i]=max(dp[1][x][j]+dp[0][y][ij],dp[0][x][j]+dp[1][y][ij]+wx,y)dp[1][x][i] = \max(dp[1][x][j] + dp[0][y][i - j], dp[0][x][j] + dp[1][y][i - j] + w_{x, y} )

  • xx 点度数为 2,这时也会有两种情况,xx 本身就是在一条链的中间,或者 xx 本身是一条链的端点,yy 也是一条链的端点,然后把 xxyy 连到一起,注意也要加上边权。

    dp[2][x][i]=max(dp[2][x][j]+dp[0][y][ij],dp[1][x][j]+dp[1][y][ij]+wx,y)dp[2][x][i] = \max(dp[2][x][j] + dp[0][y][i - j], dp[1][x][j] + dp[1][y][i - j] + w_{x, y} )

最后合并:

dp[0][x][i]=max(dp[1][x][i1],dp[2][x][i])dp[0][x][i] = \max(dp[1][x][i - 1], dp[2][x][i])

答案就是:dp[0][1][k]dp[0][1][k]

这样的复杂度是 O(nk2)O(nk^2) 的,显然无法通过此题,考虑优化。

不难发现dpdp 值是一个上凸函数,这时WQS二分就登场了。

我们二分一下斜率 midmid,每次暴力计算一下斜率为 midmid 的直线的值,这时没有了 kk 条边的限制,我们可以 O(n)O(n) 计算出 dpdp 值。

然后计算切点的位置,如果在 kk 左边,斜率变小,反之斜率变大。

于是这题就可以在 O(nlogk)O(n\log k) 的复杂度内完成啦。

# Code

mark
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <bits/stdc++.h>
#define ll long long

using namespace std;

namespace IO{
inline ll read(){
ll x = 0, f = 1;
char ch = getchar();
while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x * f;
}

template <typename T> inline void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
}
using namespace IO;

const ll N = 3e5 + 10;
ll n, k, tot, mid;
struct node{
ll v, w;
};
vector <node> g[N];

struct data{
ll val, cnt;
inline bool operator < (const data &b) const {return val != b.val ? val < b.val : cnt < b.cnt;}
inline data operator + (const data &b) const {return (data){val + b.val, cnt + b.cnt};}
inline data operator + (ll b) const {return (data){val + b, cnt};}
}dp[3][N];

inline data upd(data b) {return (data){b.val - mid, b.cnt + 1};}// y 合并到 x 上,对于 x 来说多了一条链

inline void dfs(ll u, ll fa){
dp[2][u] = max(dp[2][u], (data){-mid, 1});
for(auto y : g[u]){
if(y.v == fa) continue;
dfs(y.v, u);
dp[2][u] = max(dp[2][u] + dp[0][y.v], upd(dp[1][u] + dp[1][y.v] + y.w));
dp[1][u] = max(dp[1][u] + dp[0][y.v], dp[0][u] + dp[1][y.v] + y.w);
dp[0][u] = dp[0][u] + dp[0][y.v];
}
dp[0][u] = max(dp[0][u], max(upd(dp[1][u]), dp[2][u]));
}

inline ll solve(){
ll l = -tot, r = tot, res = 0;
while(l <= r){
mid = (l + r) >> 1;
memset(dp, 0, sizeof(dp));
dfs(1, 0);
if(dp[0][1].cnt >= k) l = mid + 1, res = mid;
else r = mid - 1;
}
memset(dp, 0, sizeof(dp));
mid = res, dfs(1, 0);
return res * k + dp[0][1].val;
}

signed main(){
n = read(), k = read() + 1;
for(ll i = 1; i < n; ++i){
ll u = read(), v = read(), w = read();
tot += abs(w);
g[u].push_back((node){v, w}), g[v].push_back((node){u, w});
}
write(solve()), puts("");
return 0;
}

# End