# 主要思想
WQS二分也叫带权二分,主要解决一类凸包上的问题。
假设当前你有一个很难计算的凸函数 ,此时你要计算一个指定的 ,如何快速求出呢?
我们可以去二分一个斜率 ,然后计算斜率为 的一条直线落到这个凸壳上时的交点。
设这个交点为 ,再设函数 表示表示斜率为 的直线落到凸壳上时的截距。
那么:
由于 是一条直线,感性理解一下,我们在计算这条直线上的值时就不需要考虑原题中的限制条件了。
于是我们就可以通过二分斜率来快速的找出答案了。
可以通过下面的例题来理解一下。
# 例题
# Description
P4383 [八省联考 2018] 林克卡特树
# Solution
首先转化一下题意,删边再加边显然是我们不想看到的,所以我们可以理解为从这棵树中选择 条不相交的链,使它们的权值和最大。
我们先推出 60pts 的暴力 dp。
设 表示以 为根的子树中选择 条完整链, 点的度数分别为 时的最大权值和。
然后分 3 种情况讨论:
-
点度数为 0,那么此时 不在链上,于是合并一下子树的权值即可。
-
点度数为 1, 可能本身是一条链上的端点,或者当前从儿子连上来一条边,如果连上来一条边的话要记得加上边权。
-
点度数为 2,这时也会有两种情况, 本身就是在一条链的中间,或者 本身是一条链的端点, 也是一条链的端点,然后把 和 连到一起,注意也要加上边权。
最后合并:
答案就是:
这样的复杂度是 的,显然无法通过此题,考虑优化。
不难发现, 值是一个上凸函数,这时WQS二分就登场了。
我们二分一下斜率 ,每次暴力计算一下斜率为 的直线的值,这时没有了 条边的限制,我们可以 计算出 值。
然后计算切点的位置,如果在 左边,斜率变小,反之斜率变大。
于是这题就可以在 的复杂度内完成啦。
# Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
using namespace std;
namespace IO{
inline ll read(){
ll x = 0, f = 1;
char ch = getchar();
while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();}
while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x * f;
}
template <typename T> inline void write(T x){
if(x < 0) putchar('-'), x = -x;
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
}
using namespace IO;
const ll N = 3e5 + 10;
ll n, k, tot, mid;
struct node{
ll v, w;
};
vector <node> g[N];
struct data{
ll val, cnt;
inline bool operator < (const data &b) const {return val != b.val ? val < b.val : cnt < b.cnt;}
inline data operator + (const data &b) const {return (data){val + b.val, cnt + b.cnt};}
inline data operator + (ll b) const {return (data){val + b, cnt};}
}dp[3][N];
inline data upd(data b) {return (data){b.val - mid, b.cnt + 1};}// y 合并到 x 上,对于 x 来说多了一条链
inline void dfs(ll u, ll fa){
dp[2][u] = max(dp[2][u], (data){-mid, 1});
for(auto y : g[u]){
if(y.v == fa) continue;
dfs(y.v, u);
dp[2][u] = max(dp[2][u] + dp[0][y.v], upd(dp[1][u] + dp[1][y.v] + y.w));
dp[1][u] = max(dp[1][u] + dp[0][y.v], dp[0][u] + dp[1][y.v] + y.w);
dp[0][u] = dp[0][u] + dp[0][y.v];
}
dp[0][u] = max(dp[0][u], max(upd(dp[1][u]), dp[2][u]));
}
inline ll solve(){
ll l = -tot, r = tot, res = 0;
while(l <= r){
mid = (l + r) >> 1;
memset(dp, 0, sizeof(dp));
dfs(1, 0);
if(dp[0][1].cnt >= k) l = mid + 1, res = mid;
else r = mid - 1;
}
memset(dp, 0, sizeof(dp));
mid = res, dfs(1, 0);
return res * k + dp[0][1].val;
}
signed main(){
n = read(), k = read() + 1;
for(ll i = 1; i < n; ++i){
ll u = read(), v = read(), w = read();
tot += abs(w);
g[u].push_back((node){v, w}), g[v].push_back((node){u, w});
}
write(solve()), puts("");
return 0;
}