# Description

Luogu传送门

# Solution

题目让我们求 kk 个串 LCP\text{LCP} 之和最大,所以考虑建 trie\text{trie} 树,然后我们就可以跑树形 dp\text{dp}

dpu,idp_{u, i} 表示在以 uu 为根的子树中,选取 ii 个结束节点,两两 LCP\text{LCP} 之和最大是多少。

转移的时候为了避免重复转移,我们只加上 xx 点到其父亲的贡献,即:

dpx,i=max{dpx,ij+dpy,j+(n2)}dp_{x, i} = \max\{dp_{x, i - j} + dp_{y, j} + \dbinom{n}{2}\}

就是 xx 子树内所有被选的点两两组成点对都会对 xfaxx - fa_x 这条边造成贡献,所以贡献总数就是 (n2)\dbinom{n}{2}

但是暴力 dp\text{dp} 无法通过此题,考虑如何优化。

不难发现,只有结束节点,以及任意两两结束节点的 lca\text{lca} 对答案会有贡献,所以建个虚树在上面跑 dp\text{dp} 即可。

建完虚树之后,xxfaxfa_x 的距离就不再是 1 了,所以贡献要乘上 depxdepfa[x]dep_x - dep_{fa[x]},也就是:

dpx,i=max{dpx,ij+dpy,j+(n2)×(depxdepfa[x])}dp_{x, i} = \max\{dp_{x, i - j} + dp_{y, j} + \dbinom{n}{2} \times (dep_x - dep_{fa[x]})\}

建虚树的过程也不用那么麻烦,我们只需要 dfs\text{dfs} 一遍,如果当前点 xx 为根或是结束节点,那么保留下来,如果 xx 子树内结束节点个数大于 1 也保留下来(这样已经可以通过了)。

# Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <bits/stdc++.h>
#define ll long long
#define pb push_back

using namespace std;

namespace IO{
inline int read(){
int x = 0;
char ch = getchar();
while(!isdigit(ch)) ch = getchar();
while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x;
}

template <typename T> inline void write(T x){
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
}
using namespace IO;

const int N = 4010;
const int M = 1e6;
int n, m;
int ch[M][27], tot = 1;
int End[M], siz[M], dep[M];
char s[N];

inline void insert(char s[]){
int len = strlen(s), now = 1;
for(int i = 0; i < len; ++i){
int c = s[i] - 'a';
if(!ch[now][c]) ch[now][c] = ++tot, siz[now]++;
now = ch[now][c];
}
End[now]++;
}

vector <int> g[N];
int cnt, num[N];

inline void dfs(int x, int deep, int lst){
if(x == 1 || siz[x] > 1 || End[x])
dep[++cnt] = deep, g[lst].pb(cnt), lst = cnt, num[cnt] = End[x];
for(int i = 0; i < 26; ++i)
if(ch[x][i]) dfs(ch[x][i], deep + 1, lst);
}

int dp[N][N];

inline void solve(int x, int fa){
dp[x][0] = 0;
for(auto y : g[x]){
solve(y, x), num[x] += num[y];
for(int j = min(num[x], m); j; --j)
for(int k = min(num[y], j); k; --k)
dp[x][j] = max(dp[x][j], dp[x][j - k] + dp[y][k]);
}
for(int i = 1; i <= num[x]; ++i)
dp[x][i] += (dep[x] - dep[fa]) * (i * (i - 1) / 2);
}

signed main(){
n = read(), m = read();
for(int i = 1; i <= n; ++i) scanf("%s", s), insert(s);
dfs(1, 0, 0), solve(1, 0);
write(dp[1][m]), puts("");
return 0;
}