# Description

Luogu 传送门

vjudge 传送门

不得不说,洛谷的翻译是真的迷,这里再给出一个翻译:

给你一个长度为 nn 的数列,让你把这个数列划分成 k+1k + 1 段。

定义一段区间的价值为区间内数字两两异或的和,即 i=lrj=i+1r(aiaj)\sum\limits_{i = l}^r\sum\limits_{j = i + 1}^r(a_i \oplus a_j)

求一种划分方案使得 k+1k + 1 段价值和最小,只需要输出价值和。

数据范围:1k<n5000,0ai1091 \leq k < n \leq 5000, 0 \leq a_i \leq 10^9

# Solution

首先考虑朴素的 dpdp,定义 dpi,jdp_{i, j} 表示前 ii 个数,划分成 jj 段的最小价值。

转移方程:

dpi,j=min{dpk,j1+w(k+1,i)}dp_{i, j} = \min\{dp_{k, j - 1} + w(k + 1, i)\}

也就是前 kk 个数划分成 j1j - 1 段,k+1ik + 1 \sim i 放到一段里,这一段的价值就是 w(k+1,i)w(k + 1, i)

w(i,j)w(i, j) 可以 O(n2)O(n^2) 预处理出来,然后转移也是 O(n3)O(n^3) 的,总复杂度 O(n3)O(n^3), 无法通过此题。

重点来了,我们要使用四边形不等式优化 dp。

关于四边形不等式优化 dp 的详细讲解请大家左转百度一下吧 QwQ

我们知道,四边形不等式需要满足如下条件:

w(a,c)+w(b,d)w(a,d)+w(b,c)(abcd)w(a, c) + w(b, d) \leq w(a, d) + w(b, c)\ \ (a \leq b \leq c \leq d)

同时,我们知道只需要证明 w(i,j)+w(i+1,j+1)w(i,j+1)+w(i+1,j)w(i, j) + w(i + 1, j + 1) \leq w(i, j + 1) + w(i + 1, j) 即可证明上述条件。

所以我们来推导一下:

w(i,j+1)+w(i+1,j)w(i,j)w(i+1,j+1)=(w(i,j+1)w(i,j))+(w(i+1,j+1)w(i+1,j))=t=ij(aj+1at)t=i+1j(aj+1at)0\begin{aligned} & w(i, j + 1) + w(i + 1, j) - w(i, j) - w(i + 1, j + 1) \\ =& (w(i, j + 1) - w(i, j)) + (w(i + 1, j + 1) - w(i + 1, j)) \\ =& \sum_{t = i}^j (a_{j + 1} \oplus a_t) - \sum_{t = i + 1}^j (a_{j + 1} \oplus a_t) \\ \geq& 0 \end{aligned}

所以我们的转移方程满足四边形不等式。那么这个性质有什么用呢?

众所周知,拥有了这个性质之后我们就有了决策单调性,也就是 pi,j1kpi+1,jp_{i, j - 1} \leq k \leq p_{i + 1, j}

pi,jp_{i, j} 是转移到 dpi,jdp_{i, j} 时的最优决策点,就是上面转移方程里使得 dpi,jdp_{i, j} 取最优值时转移过来的那个 kk

所以我们在对 dpi,jdp_{i, j} 进行转移时只需要枚举 pi,j1pi+1,jp_{i, j - 1} \sim p_{i + 1, j} 即可。

# Code

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 5010;
const ll INF = 1e18;
int n, m;
int a[N], p[N][N];
ll f[N][N], s[N][N];
inline ll sum(int x, int y){
    return s[y][y] - s[x - 1][y] - s[y][x - 1] + s[x - 1][x - 1];
}
int main(){
    scanf("%d%d", &n, &m); m++;
    for(int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    for(int i = 1; i <= n; ++i)
        for(int j = 1; j <= n; ++j)
            s[i][j] = (a[i] ^ a[j]) + s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1];
    for(int i = 0; i <= n; ++i)
        for(int j = 0; j <= m; ++j) f[i][j] = INF;
    f[0][0] = 0;
    for(int j = 1; j <= m; ++j){
        p[n + 1][j] = n;
        for(int i = n; i >= 1; --i)
            for(int k = p[i][j - 1]; k <= p[i + 1][j]; ++k)
                if(f[i][j] > f[k][j - 1] + sum(k + 1, i))
                    f[i][j] = f[k][j - 1] + sum(k + 1, i), p[i][j] = k;
    }
    printf("%lld\n", f[n][m] >> 1);
    return 0;
}

_EOF_\_EOF\_