# Description

Luogu 传送门

# Solution

题目让我们求 kk 个串 LCP\text{LCP} 之和最大,所以考虑建 trie\text{trie} 树,然后我们就可以跑树形 dp\text{dp}

dpu,idp_{u, i} 表示在以 uu 为根的子树中,选取 ii 个结束节点,两两 LCP\text{LCP} 之和最大是多少。

转移的时候为了避免重复转移,我们只加上 xx 点到其父亲的贡献,即:

dpx,i=max{dpx,ij+dpy,j+(n2)}dp_{x, i} = \max\{dp_{x, i - j} + dp_{y, j} + \dbinom{n}{2}\}

就是 xx 子树内所有被选的点两两组成点对都会对 xfaxx - fa_x 这条边造成贡献,所以贡献总数就是 (n2)\dbinom{n}{2}

但是暴力 dp\text{dp} 无法通过此题,考虑如何优化。

不难发现,只有结束节点,以及任意两两结束节点的 lca\text{lca} 对答案会有贡献,所以建个虚树在上面跑 dp\text{dp} 即可。

建完虚树之后,xxfaxfa_x 的距离就不再是 1 了,所以贡献要乘上 depxdepfa[x]dep_x - dep_{fa[x]},也就是:

dpx,i=max{dpx,ij+dpy,j+(n2)×(depxdepfa[x])}dp_{x, i} = \max\{dp_{x, i - j} + dp_{y, j} + \dbinom{n}{2} \times (dep_x - dep_{fa[x]})\}

建虚树的过程也不用那么麻烦,我们只需要 dfs\text{dfs} 一遍,如果当前点 xx 为根或是结束节点,那么保留下来,如果 xx 子树内结束节点个数大于 1 也保留下来(这样已经可以通过了)。

# Code

#include <bits/stdc++.h>
#define ll long long
#define pb push_back
using namespace std;
namespace IO{
    inline int read(){
        int x = 0;
        char ch = getchar();
        while(!isdigit(ch)) ch = getchar();
        while(isdigit(ch)) x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
        return x;
    }
    template <typename T> inline void write(T x){
        if(x > 9) write(x / 10);
        putchar(x % 10 + '0');
    }
}
using namespace IO;
const int N = 4010;
const int M = 1e6;
int n, m;
int ch[M][27], tot = 1;
int End[M], siz[M], dep[M];
char s[N];
inline void insert(char s[]){
    int len = strlen(s), now = 1;
    for(int i = 0; i < len; ++i){
        int c = s[i] - 'a';
        if(!ch[now][c]) ch[now][c] = ++tot, siz[now]++;
        now = ch[now][c];
    }
    End[now]++;
}
vector <int> g[N];
int cnt, num[N];
inline void dfs(int x, int deep, int lst){
    if(x == 1 || siz[x] > 1 || End[x])
        dep[++cnt] = deep, g[lst].pb(cnt), lst = cnt, num[cnt] = End[x];
    for(int i = 0; i < 26; ++i)
        if(ch[x][i]) dfs(ch[x][i], deep + 1, lst);
}
int dp[N][N];
inline void solve(int x, int fa){
    dp[x][0] = 0;
    for(auto y : g[x]){
        solve(y, x), num[x] += num[y];
        for(int j = min(num[x], m); j; --j)
            for(int k = min(num[y], j); k; --k)
                dp[x][j] = max(dp[x][j], dp[x][j - k] + dp[y][k]);
    }
    for(int i = 1; i <= num[x]; ++i)
        dp[x][i] += (dep[x] - dep[fa]) * (i * (i - 1) / 2);
}
signed main(){
    n = read(), m = read();
    for(int i = 1; i <= n; ++i) scanf("%s", s), insert(s);
    dfs(1, 0, 0), solve(1, 0);
    write(dp[1][m]), puts("");
    return 0;
}