# 主要思想
WQS 二分也叫带权二分,主要解决一类凸包上的问题。
假设当前你有一个很难计算的凸函数 ,此时你要计算一个指定的 ,如何快速求出呢?
我们可以去二分一个斜率 ,然后计算斜率为 的一条直线落到这个凸壳上时的交点。
设这个交点为 ,再设函数 表示表示斜率为 的直线落到凸壳上时的截距。
那么:
由于 是一条直线,感性理解一下,我们在计算这条直线上的值时就不需要考虑原题中的限制条件了。
于是我们就可以通过二分斜率来快速的找出答案了。
可以通过下面的例题来理解一下。
# 例题
# Description
P4383 [八省联考 2018] 林克卡特树
# Solution
首先转化一下题意,删边再加边显然是我们不想看到的,所以我们可以理解为从这棵树中选择 条不相交的链,使它们的权值和最大。
我们先推出 60pts 的暴力 dp。
设 表示以 为根的子树中选择 条完整链, 点的度数分别为 时的最大权值和。
然后分 3 种情况讨论:
点度数为 0,那么此时 不在链上,于是合并一下子树的权值即可。
点度数为 1, 可能本身是一条链上的端点,或者当前从儿子连上来一条边,如果连上来一条边的话要记得加上边权。
点度数为 2,这时也会有两种情况, 本身就是在一条链的中间,或者 本身是一条链的端点, 也是一条链的端点,然后把 和 连到一起,注意也要加上边权。
最后合并:
答案就是:
这样的复杂度是 的,显然无法通过此题,考虑优化。
不难发现, 值是一个上凸函数,这时 WQS 二分就登场了。
我们二分一下斜率 ,每次暴力计算一下斜率为 的直线的值,这时没有了 条边的限制,我们可以 计算出 值。
然后计算切点的位置,如果在 左边,斜率变小,反之斜率变大。
于是这题就可以在 的复杂度内完成啦。
# Code
#include <bits/stdc++.h> | |
#define ll long long | |
using namespace std; | |
namespace IO{ | |
inline ll read(){ | |
ll x = 0, f = 1; | |
char ch = getchar(); | |
while(!isdigit(ch)) {if(ch == '-') f = -1; ch = getchar();} | |
while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar(); | |
return x * f; | |
} | |
template <typename T> inline void write(T x){ | |
if(x < 0) putchar('-'), x = -x; | |
if(x > 9) write(x / 10); | |
putchar(x % 10 + '0'); | |
} | |
} | |
using namespace IO; | |
const ll N = 3e5 + 10; | |
ll n, k, tot, mid; | |
struct node{ | |
ll v, w; | |
}; | |
vector <node> g[N]; | |
struct data{ | |
ll val, cnt; | |
inline bool operator < (const data &b) const {return val != b.val ? val < b.val : cnt < b.cnt;} | |
inline data operator + (const data &b) const {return (data){val + b.val, cnt + b.cnt};} | |
inline data operator + (ll b) const {return (data){val + b, cnt};} | |
}dp[3][N]; | |
inline data upd(data b) {return (data){b.val - mid, b.cnt + 1};}//y 合并到 x 上,对于 x 来说多了一条链 | |
inline void dfs(ll u, ll fa){ | |
dp[2][u] = max(dp[2][u], (data){-mid, 1}); | |
for(auto y : g[u]){ | |
if(y.v == fa) continue; | |
dfs(y.v, u); | |
dp[2][u] = max(dp[2][u] + dp[0][y.v], upd(dp[1][u] + dp[1][y.v] + y.w)); | |
dp[1][u] = max(dp[1][u] + dp[0][y.v], dp[0][u] + dp[1][y.v] + y.w); | |
dp[0][u] = dp[0][u] + dp[0][y.v]; | |
} | |
dp[0][u] = max(dp[0][u], max(upd(dp[1][u]), dp[2][u])); | |
} | |
inline ll solve(){ | |
ll l = -tot, r = tot, res = 0; | |
while(l <= r){ | |
mid = (l + r) >> 1; | |
memset(dp, 0, sizeof(dp)); | |
dfs(1, 0); | |
if(dp[0][1].cnt >= k) l = mid + 1, res = mid; | |
else r = mid - 1; | |
} | |
memset(dp, 0, sizeof(dp)); | |
mid = res, dfs(1, 0); | |
return res * k + dp[0][1].val; | |
} | |
signed main(){ | |
n = read(), k = read() + 1; | |
for(ll i = 1; i < n; ++i){ | |
ll u = read(), v = read(), w = read(); | |
tot += abs(w); | |
g[u].push_back((node){v, w}), g[v].push_back((node){u, w}); | |
} | |
write(solve()), puts(""); | |
return 0; | |
} |