首先考虑合法性的判定。由于异或可以看成是不进位的二进制加法,若某一位上出现了至少两个 $1$,那么普通加法在这一位上就会产生进位,从而与异或结果不同。反之,若每一位上至多出现一个 $1$,那么普通加法和异或的结果相同。因此,一条链合法的充要条件就是链上所有数的二进制表示中,每一位至多被一个数占用。

下面令 $V$ 表示值域。对于 $a_i>0$ 的点,由于每个非零数至少占用一个二进制位,而合法链上每一位只能出现一次,所以一条合法链中的非零点数量不会超过 $\log V$。不过,$a_i=0$ 的点比较特殊,可以任意插入到非零点之间。于是我们可以把一条链看成若干个非零点和若干段连续的 $0$ 组成的序列,只需要对非零点维护二进制位冲突,对连续的 $0$ 段整体统计贡献即可。

对于询问 $(u,v)$ 考虑以下三种情况:

  1. 完全位于 $u \to \operatorname{LCA}(u,v)$ 这一侧的链
  2. 完全位于 $v \to \operatorname{LCA}(u,v)$ 这一侧的链
  3. $u \to \operatorname{LCA}(u,v) \to v$ 上的链

对于前两种情况,可以维护出点 $u$ 往上第一次不合法的点 $bad_u$,则以 $u$ 作为下端点的链的方案数为 $dep_u - dep_{bad_u}$。那么对于每一个点,直接前缀和维护答案即可。但是有一个问题就是 $u$ 往上第一次不合法的点 $bad_u$ 可能在 $\operatorname{LCA}(u,v)$ 的上方,因此还需要找到一个 $top_u$ 点表示这是这条链上端点的边界,超过该点的祖先的 $bad_u$ 都不合法。

剩下就是 $top_u \to \operatorname{LCA}(u,v) \to top_v$ 上的链了。可以发现这条链上的非零点的数量不会超过 $\log V$,因此可以暴力枚举这些非零点的位置。对于压缩后的序列,使用双指针维护当前窗口。窗口合法当且仅当所有二进制位出现次数均不超过 $1$。遇到非零点时,更新它占用的二进制位;如果出现冲突,就不断移动左端点直到重新合法。对于一段 $0$,由于它不占用任何二进制位,因此它可以和当前窗口中所有合法前缀自由组合。同时,这段 $0$ 内部也能形成子链,贡献为 $\frac{l(l+1)}{2}$,其中 $l$ 是这段 $0$ 的长度。

至于如何维护 $bad_u$ 和 $top_u$,可以利用 dfs 序和倍增,这里就不再赘述细节。总时间复杂度为 $O(n \log V + q \log^2 V)$。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
constexpr int V = 20;
void solve()
{
int n = read(), q = read(), times = 0;
vector <int> a(n + 1), dep(n + 1), dfn(n + 1), mp(n + 1), nonzero(n + 1), bad(n + 1);
vector <i64> sum(n + 1);
vector <vector <int>> bits(n + 1, vector <int> (V + 1)), f(n + 1, vector <int> (V + 1));
for (int i = 1; i <= n; ++i) a[i] = read();
vector <vector <int>> adj(n + 1);
for (int i = 1; i < n; ++i)
{
int u = read(), v = read();
adj[u].push_back(v); adj[v].push_back(u);
}
auto dfs1 = [&] (auto self, int u, int fa) -> void
{
dep[u] = dep[fa] + 1;
dfn[u] = ++times; mp[times] = u;
f[u][0] = fa;
for (int i = 0; i < V; ++i) f[u][i + 1] = f[f[u][i]][i];
int lst = bad[fa];
for (int i = 0; i < V; ++i)
{
bits[u][i] = bits[fa][i];
if (a[u] >> i & 1)
{
lst = max(lst, bits[fa][i]);
bits[u][i] = dfn[u];
}
}
bad[u] = lst;
sum[u] = sum[fa] + dep[u] - dep[mp[lst]];
nonzero[u] = nonzero[fa];
if (a[u]) nonzero[u] = dfn[u];
for (auto v : adj[u])
{
if (v == fa) continue;
self(self, v, u);
}
};
auto get_LCA = [&] (int u, int v) -> int
{
if (dep[u] < dep[v]) swap(u,v);
for (int i = V; ~i; --i)
{
if (dep[f[u][i]] >= dep[v]) u = f[u][i];
if (u == v) return u;
}
for (int i = V; ~i; --i)
if (f[u][i] != f[v][i]) u = f[u][i],v = f[v][i];
return f[u][0];
};
dfs1(dfs1, 1, 0);
auto get_top = [&] (int u, int lca) -> int
{
int v = u;
for (int i = V; ~i; --i)
if (bad[f[v][i]] >= dfn[lca]) v = f[v][i];
return v;
};
auto chain = [&] (int u, int lca, int v) -> i64
{
auto work = [&] (int u, int bound) -> vector <pii>
{
vector <pii> seg;
while (u && dep[u] > dep[bound])
{
if (a[u])
{
seg.push_back({a[u], 1});
u = f[u][0];
}
else
{
int nxt = mp[nonzero[u]];
if (dep[nxt] < dep[bound]) nxt = bound;
seg.push_back({0, dep[u] - dep[nxt]});
u = nxt;
}
}
return seg;
};
auto A = work(u, f[lca][0]), B = work(v, lca);
reverse(B.begin(), B.end());
auto C = A;
for (int i = 0; i < (int) B.size(); ++i)
{
if (!A.back().first && !i && !(*B.begin()).first) (*(--C.end())).second += (*B.begin()).second;
else C.push_back(B[i]);
}
vector <int> d(V + 1);
i64 res = 0, tot = 0;
for (int l = 0, r = 0; r < (int) C.size(); ++r)
{
auto [op, cnt] = C[r];
auto check = [&] () -> bool
{
for (int i = 0; i < V; ++i)
if (d[i] > 1) return false;
return true;
};
for (int i = 0; i < V; ++i)
if (op >> i & 1) ++d[i];
while (!check())
{
for (int i = 0; i < V; ++i)
if (C[l].first >> i & 1) --d[i];
tot -= C[l++].second;
}
res += tot * cnt;
if (!op) res += 1ll * cnt * (cnt + 1) / 2;
else ++res;
tot += cnt;
}
return res;
};
while (q--)
{
int u = read(), v = read(), lca = get_LCA(u, v);
int topu = get_top(u, lca), topv = get_top(v, lca);
i64 ans = (sum[u] - sum[topu]) + (sum[v] - sum[topv]);
ans += chain(topu, lca, topv); //topu -> lca -> topv
printf("%lld\n", ans);
}
}