1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
| constexpr int V = 20; void solve() { int n = read(), q = read(), times = 0; vector <int> a(n + 1), dep(n + 1), dfn(n + 1), mp(n + 1), nonzero(n + 1), bad(n + 1); vector <i64> sum(n + 1); vector <vector <int>> bits(n + 1, vector <int> (V + 1)), f(n + 1, vector <int> (V + 1)); for (int i = 1; i <= n; ++i) a[i] = read(); vector <vector <int>> adj(n + 1); for (int i = 1; i < n; ++i) { int u = read(), v = read(); adj[u].push_back(v); adj[v].push_back(u); } auto dfs1 = [&] (auto self, int u, int fa) -> void { dep[u] = dep[fa] + 1; dfn[u] = ++times; mp[times] = u; f[u][0] = fa; for (int i = 0; i < V; ++i) f[u][i + 1] = f[f[u][i]][i]; int lst = bad[fa]; for (int i = 0; i < V; ++i) { bits[u][i] = bits[fa][i]; if (a[u] >> i & 1) { lst = max(lst, bits[fa][i]); bits[u][i] = dfn[u]; } } bad[u] = lst; sum[u] = sum[fa] + dep[u] - dep[mp[lst]]; nonzero[u] = nonzero[fa]; if (a[u]) nonzero[u] = dfn[u]; for (auto v : adj[u]) { if (v == fa) continue; self(self, v, u); } }; auto get_LCA = [&] (int u, int v) -> int { if (dep[u] < dep[v]) swap(u,v); for (int i = V; ~i; --i) { if (dep[f[u][i]] >= dep[v]) u = f[u][i]; if (u == v) return u; } for (int i = V; ~i; --i) if (f[u][i] != f[v][i]) u = f[u][i],v = f[v][i]; return f[u][0]; }; dfs1(dfs1, 1, 0); auto get_top = [&] (int u, int lca) -> int { int v = u; for (int i = V; ~i; --i) if (bad[f[v][i]] >= dfn[lca]) v = f[v][i]; return v; }; auto chain = [&] (int u, int lca, int v) -> i64 { auto work = [&] (int u, int bound) -> vector <pii> { vector <pii> seg; while (u && dep[u] > dep[bound]) { if (a[u]) { seg.push_back({a[u], 1}); u = f[u][0]; } else { int nxt = mp[nonzero[u]]; if (dep[nxt] < dep[bound]) nxt = bound; seg.push_back({0, dep[u] - dep[nxt]}); u = nxt; } } return seg; }; auto A = work(u, f[lca][0]), B = work(v, lca); reverse(B.begin(), B.end()); auto C = A; for (int i = 0; i < (int) B.size(); ++i) { if (!A.back().first && !i && !(*B.begin()).first) (*(--C.end())).second += (*B.begin()).second; else C.push_back(B[i]); } vector <int> d(V + 1); i64 res = 0, tot = 0; for (int l = 0, r = 0; r < (int) C.size(); ++r) { auto [op, cnt] = C[r]; auto check = [&] () -> bool { for (int i = 0; i < V; ++i) if (d[i] > 1) return false; return true; }; for (int i = 0; i < V; ++i) if (op >> i & 1) ++d[i]; while (!check()) { for (int i = 0; i < V; ++i) if (C[l].first >> i & 1) --d[i]; tot -= C[l++].second; } res += tot * cnt; if (!op) res += 1ll * cnt * (cnt + 1) / 2; else ++res; tot += cnt; } return res; }; while (q--) { int u = read(), v = read(), lca = get_LCA(u, v); int topu = get_top(u, lca), topv = get_top(v, lca); i64 ans = (sum[u] - sum[topu]) + (sum[v] - sum[topv]); ans += chain(topu, lca, topv); printf("%lld\n", ans); } }
|