根据贪心,不难想到每次会把最长队伍末尾的那辆车移动到最短队伍的末尾。但由于 $k$ 的存在,会导致一些冗余移动的存在。设需要挪动 $C$ 辆车,则怒气值可以表示为 $f(C) + kC$,其中 $f(C)$ 是排队所产生的怒气值,$kC$ 为变道产生的额外怒气值。仔细分析以后,可以发现这是一个凸函数,因此考虑三分答案。

一开始想要三分需要挪车的最短长度 $y$,但是不能忽略 $k$ 的影响,有些队伍的长度虽然 $> y$,但挪动不移动会更优。于是三分挪动车辆的数量才是最优的。

具体来说,可以枚举哪些队伍的车辆会减少/增加。若现在考虑会减少的队伍的车辆,给 $a_i$ 排序后,设当前最长队伍的车辆数为 $x$,次长的为 $y$ ($x \neq y$),然后长度为 $x,y$ 的队伍的数量分别为 $f_x,f_y$。若共需要移动 $C$ 辆车,则有两种情况:

  • $C \ge (x - y) \times f_x$,也就是说长度为 $x$ 的车可以直接变为 $y$,$C \leftarrow C - (x - y) \times f_x; f_y \leftarrow f_x + f_y; f_x \leftarrow 0$。

  • $C < (x - y) \times f_x$,此时会产生新的队伍长度,也就是 $C \leftarrow 0; f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} + C \bmod f_x; \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor} + (f_x - C \bmod f_x)$。

可以发现最后队伍长度的种类数不会超过 $n + 2$,因此这是 $O(n)$ 的。考虑增加的队伍的车辆同理,用 STL 来写会简单一点。但是由于多了一支 $\log$,实测会超时:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
ll tot = sum * k,res = sum,number = sum;
set <int> s;map <int,int> bg,sm;
s.insert (-1e9);
for (int i = 1;i <= n;++i) s.insert (a[i]),++bg[a[i]];
while (sum)
{
int x = *(--s.end ()),num = bg[x];s.erase (x);
int y = *(--s.end ());
if (sum >= 1ll * (x - y) * num)
{
sum -= 1ll * (x - y) * num;
bg[y] += num;bg[x] = 0;
}
else
{
bg[x] = 0;
int tmp = sum % num;
if (tmp) bg[x - sum / num - 1] += tmp;
bg[x - sum / num] += num - tmp;
sum = 0;
}
}
s.clear ();
for (auto [x,num] : bg)
if (num) s.insert (x),sm[x] = num;
s.insert (1e9);
while (res)
{
int x = *s.begin (),num = sm[x];s.erase (x);
int y = *s.begin ();
if (res >= 1ll * (y - x) * num)
{
res -= 1ll * (y - x) * num;
sm[y] += num;sm[x] = 0;
}
else
{
sm[x] = 0;
int tmp = res % num;
if (tmp) sm[x + res / num + 1] += tmp;
sm[x + res / num] += num - tmp;
res = 0;
}
}
for (auto [x,num] : sm) tot += 1ll * x * (x + 1) / 2 * num;
return tot;
};

再次思考可以发现 STL 的 $\log$ 完全是多余的,可以通过数组来替代,但需要小心清空与去重的问题。最后的 AC 代码如下,时间复杂度 $O(n \log n)$:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#include <bits/stdc++.h>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 2e18
#define pii pair <int,int>
using namespace std;
const int MAX = 2e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int a[MAX],b[MAX];
vector <int> bg (1000001,0),sm (1000001,0);
void solve ()
{
int n = read (),k = read ();ll ans = INF;
for (int i = 1;i <= n;++i) a[i] = read ();
sort (a + 1,a + 1 + n);
auto check = [&] (ll sum) -> ll
{
ll tot = sum * k,res = sum;int cnt = 0;
vector <int> p;
for (int i = 1;i <= n;++i) p.push_back (a[i]);
for (int i = 1;i <= n;++i)
{
if (!bg[a[i]]) b[++cnt] = a[i];
++bg[a[i]];
}
b[0] = -1e9;
while (sum > 0)
{
int x = b[cnt--],num = bg[x];
int y = b[cnt];
if (sum >= 1ll * (x - y) * num)
{
sum -= 1ll * (x - y) * num;
bg[y] += num;bg[x] = 0;
}
else
{
bg[x] = 0;
int tmp = sum % num;
bg[x - sum / num] += num - tmp,p.push_back (x - sum / num);
if (tmp) bg[x - sum / num - 1] += tmp,p.push_back (x - sum / num - 1);
sum = 0;
}
}
cnt = 0;
for (auto v : p)
if (bg[v]) b[++cnt] = v,sm[v] = bg[v],bg[v] = 0;
p.clear ();
for (int i = 1;i <= cnt;++i) p.push_back (b[i]);
b[++cnt] = 1e9;cnt = 1;
while (res > 0)
{
int x = b[cnt++],num = sm[x];
int y = b[cnt];
if (res >= 1ll * (y - x) * num)
{
res -= 1ll * (y - x) * num;
sm[y] += num;sm[x] = 0;
}
else
{
sm[x] = 0;
int tmp = res % num;
if (tmp) sm[x + res / num + 1] += tmp,p.push_back (x + res / num + 1);
sm[x + res / num] += num - tmp,p.push_back (x + res / num);
res = 0;
}
}
for (auto v : p) tot += 1ll * v * (v + 1) / 2 * sm[v],sm[v] = 0;
return tot;
};
ll l = 0,r = accumulate (a + 1,a + n + 1,0ll);
while (l < r)
{
ll midl = l + (r - l) / 3,midr = r - (r - l) / 3;
ll v1 = check (midl),v2 = check (midr);
ans = min (ans,min (v1,v2));
if (v1 <= v2) r = midr - 1;
else l = midl + 1;
}
printf ("%lld\n",ans);
}
int main ()
{
int t = read ();
while (t--) solve ();
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}