根据贪心,不难想到每次会把最长队伍末尾的那辆车移动到最短队伍的末尾。但由于 $k$ 的存在,会导致一些冗余移动的存在。设需要挪动 $C$ 辆车,则怒气值可以表示为 $f(C) + kC$,其中 $f(C)$ 是排队所产生的怒气值,$kC$ 为变道产生的额外怒气值。仔细分析以后,可以发现这是一个凸函数,因此考虑三分答案。
一开始想要三分需要挪车的最短长度 $y$,但是不能忽略 $k$ 的影响,有些队伍的长度虽然 $> y$,但挪动不移动会更优。于是三分挪动车辆的数量才是最优的。
具体来说,可以枚举哪些队伍的车辆会减少/增加。若现在考虑会减少的队伍的车辆,给 $a_i$ 排序后,设当前最长队伍的车辆数为 $x$,次长的为 $y$ ($x \neq y$),然后长度为 $x,y$ 的队伍的数量分别为 $f_x,f_y$。若共需要移动 $C$ 辆车,则有两种情况:
$C \ge (x - y) \times f_x$,也就是说长度为 $x$ 的车可以直接变为 $y$,$C \leftarrow C - (x - y) \times f_x; f_y \leftarrow f_x + f_y; f_x \leftarrow 0$。
$C < (x - y) \times f_x$,此时会产生新的队伍长度,也就是 $C \leftarrow 0; f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor - 1} + C \bmod f_x; \leftarrow f_{x - \lfloor\frac{C}{f_x}\rfloor} + (f_x - C \bmod f_x)$。
可以发现最后队伍长度的种类数不会超过 $n + 2$,因此这是 $O(n)$ 的。考虑增加的队伍的车辆同理,用 STL 来写会简单一点。但是由于多了一支 $\log$,实测会超时:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
| ll tot = sum * k,res = sum,number = sum; set <int> s;map <int,int> bg,sm; s.insert (-1e9); for (int i = 1;i <= n;++i) s.insert (a[i]),++bg[a[i]]; while (sum) { int x = *(--s.end ()),num = bg[x];s.erase (x); int y = *(--s.end ()); if (sum >= 1ll * (x - y) * num) { sum -= 1ll * (x - y) * num; bg[y] += num;bg[x] = 0; } else { bg[x] = 0; int tmp = sum % num; if (tmp) bg[x - sum / num - 1] += tmp; bg[x - sum / num] += num - tmp; sum = 0; } } s.clear (); for (auto [x,num] : bg) if (num) s.insert (x),sm[x] = num; s.insert (1e9); while (res) { int x = *s.begin (),num = sm[x];s.erase (x); int y = *s.begin (); if (res >= 1ll * (y - x) * num) { res -= 1ll * (y - x) * num; sm[y] += num;sm[x] = 0; } else { sm[x] = 0; int tmp = res % num; if (tmp) sm[x + res / num + 1] += tmp; sm[x + res / num] += num - tmp; res = 0; } } for (auto [x,num] : sm) tot += 1ll * x * (x + 1) / 2 * num; return tot; };
|
再次思考可以发现 STL 的 $\log$ 完全是多余的,可以通过数组来替代,但需要小心清空与去重的问题。最后的 AC 代码如下,时间复杂度 $O(n \log n)$:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
| #include <bits/stdc++.h> #define init(x) memset (x,0,sizeof (x)) #define ll long long #define ull unsigned long long #define INF 2e18 #define pii pair <int,int> using namespace std; const int MAX = 2e5 + 5; const int MOD = 1e9 + 7; inline int read (); int a[MAX],b[MAX]; vector <int> bg (1000001,0),sm (1000001,0); void solve () { int n = read (),k = read ();ll ans = INF; for (int i = 1;i <= n;++i) a[i] = read (); sort (a + 1,a + 1 + n); auto check = [&] (ll sum) -> ll { ll tot = sum * k,res = sum;int cnt = 0; vector <int> p; for (int i = 1;i <= n;++i) p.push_back (a[i]); for (int i = 1;i <= n;++i) { if (!bg[a[i]]) b[++cnt] = a[i]; ++bg[a[i]]; } b[0] = -1e9; while (sum > 0) { int x = b[cnt--],num = bg[x]; int y = b[cnt]; if (sum >= 1ll * (x - y) * num) { sum -= 1ll * (x - y) * num; bg[y] += num;bg[x] = 0; } else { bg[x] = 0; int tmp = sum % num; bg[x - sum / num] += num - tmp,p.push_back (x - sum / num); if (tmp) bg[x - sum / num - 1] += tmp,p.push_back (x - sum / num - 1); sum = 0; } } cnt = 0; for (auto v : p) if (bg[v]) b[++cnt] = v,sm[v] = bg[v],bg[v] = 0; p.clear (); for (int i = 1;i <= cnt;++i) p.push_back (b[i]); b[++cnt] = 1e9;cnt = 1; while (res > 0) { int x = b[cnt++],num = sm[x]; int y = b[cnt]; if (res >= 1ll * (y - x) * num) { res -= 1ll * (y - x) * num; sm[y] += num;sm[x] = 0; } else { sm[x] = 0; int tmp = res % num; if (tmp) sm[x + res / num + 1] += tmp,p.push_back (x + res / num + 1); sm[x + res / num] += num - tmp,p.push_back (x + res / num); res = 0; } } for (auto v : p) tot += 1ll * v * (v + 1) / 2 * sm[v],sm[v] = 0; return tot; }; ll l = 0,r = accumulate (a + 1,a + n + 1,0ll); while (l < r) { ll midl = l + (r - l) / 3,midr = r - (r - l) / 3; ll v1 = check (midl),v2 = check (midr); ans = min (ans,min (v1,v2)); if (v1 <= v2) r = midr - 1; else l = midl + 1; } printf ("%lld\n",ans); } int main () { int t = read (); while (t--) solve (); return 0; } inline int read () { int s = 0;int f = 1; char ch = getchar (); while ((ch < '0' || ch > '9') && ch != EOF) { if (ch == '-') f = -1; ch = getchar (); } while (ch >= '0' && ch <= '9') { s = s * 10 + ch - '0'; ch = getchar (); } return s * f; }
|