树链剖分

P3384 【模板】轻重链剖分/树链剖分

  1. 作用
  • 维护树上路径的相关信息。

  • 常与线段树相结合。

  1. 性质
  • 所有节点都属于且仅属于一条重链,重链将树完全剖分。

  • 重链与子树内的 $\texttt{dfs}$ 序连续。【这一个性质非常有用

  • 每一条路径最多被拆分成 $\log n$ 条重链(向下经过一条轻边时,子树大小至少除以 $2$)。

  1. 一些定义
  • f[x] 节点 $x$ 的父亲。

  • sz[x] 节点 $x$ 对应的子树大小。

  • dep[x] 节点 $x$ 的深度(假定编号为 $r$ 的节点深度为 $1$)。

  • dfn[x] 节点 $x$ 的 $\texttt{dfs}$ 序。

  • hson[x] 节点 $x$ 所对应的重儿子。

  • top[x] 节点 $x$ 所在的重链的顶部节点。

  • rk[x] $\texttt{dfs}$ 序所对应的节点编号,即 rk[dfn[x]] = x

  1. 两个 $\texttt{dfs}$
  • 第一个分别求出 fa[x] sz[x] dep[x] hson[x]

具体代码如下 :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
void dfs1 (int u,int fa)
{
sz[u] = 1;
dep[u] = dep[fa] + 1;
f[u] = fa;
for (int i = head[u];i;i = nxt[i])
{
int v = to[i];
if (v == fa) continue;
dfs1 (v,u);
sz[u] += sz[v];
if (sz[hson[u]] < sz[v]) hson[u] = v;//子块大的便是重儿子
}
}
  • 第二个分别求出 top[x] dfn[x] rk[x]

具体代码如下 :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
void dfs2 (int u,int fa)
{
if (hson[u])
{
top[hson[u]] = top[u];
dfn[hson[u]] = ++cnt;
rk[cnt] = hson[u];
dfs2 (hson[u],u);//优先对重儿子进行剖分,从而保证重链上的 dfn 序连续
}
for (int i = head[u];i;i = nxt[i])
{
int v = to[i];
if (v == fa || top[v]) continue;
top[v] = v;//单独以 v 作为重链的顶端
dfn[v] = ++cnt;
rk[cnt] = v;
dfs2 (v,v);
}
}
  1. 树上两点路径权值和的修改与查询

每次选择深度大的链往上跳,直到两点在同一条链上。由于链上的 $\texttt{dfn}$ 连续,所以直接再用线段树(或树状数组)进行维护即可。

以查询为例 :

1
2
3
4
5
6
7
8
9
10
11
12
void upd (int x,int y,int v)// x - y 的最短路径上,所有点的权值均加上 v
{
int fx = top[x],fy = top[y];
while (fx != fy)
{
if (dep[fx] < dep[fy]) swap (fx,fy),swap (x,y);//选择深度大的向上跳
modify (1,1,cnt,dfn[fx],dfn[x],v);//某一段链的更新 注意 f[x] 的 dfn 序更小
x = f[fx],fx = top[x];
}
if (dep[x] > dep[y]) swap (x,y);
modify (1,1,cnt,dfn[x],dfn[y],v);//最后 x 与 y 在同一条链上
}
  1. 完整代码
点击展开完整代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 1e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int n,m,r,p,a[MAX];
int dcnt,head[MAX << 1],nxt[MAX << 1],to[MAX << 1];
ll tree[MAX << 2],tmp[MAX << 2];
int cnt,f[MAX],dfn[MAX],sz[MAX],dep[MAX],top[MAX],hson[MAX],rk[MAX];
void add (int u,int v);
void build (int cur,int l,int r);
void pushdown (int cur,int l,int r);
void modify (int cur,int l,int r,int x,int y,int v);
ll query (int cur,int l,int r,int x,int y);
void dfs1 (int u,int fa);
void dfs2 (int u,int fa);
void upd (int x,int y,int v);
ll ask (int x,int y);
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
n = read ();m = read ();r = read ();p = read ();
for (int i = 1;i <= n;++i) a[i] = read ();
for (int i = 1;i < n;++i)
{
int x = read (),y = read ();
add (x,y);add (y,x);
}
dfs1 (r,-1);
dfn[r] = ++cnt;rk[1] = top[r] = r;
dfs2 (r,-1);
build (1,1,cnt);
for (int i = 1;i <= m;++i)
{
int ty = read (),x = read (),y,val;
if (ty == 1)
{
y = read (),val = read ();
upd (x,y,val);
}
if (ty == 2)
{
y = read ();
printf ("%lld\n",ask (x,y));
}
if (ty == 3)
{
val = read ();
modify (1,1,cnt,dfn[x],dfn[x] + sz[x] - 1,val);//子树内的 dfs 序连续
}
if (ty == 4) printf ("%lld\n",query (1,1,cnt,dfn[x],dfn[x] + sz[x] - 1));
}
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
void add (int u,int v)
{
to[++dcnt] = v;
nxt[dcnt] = head[u];
head[u] = dcnt;
}
void build (int cur,int l,int r)
{
if (l == r)
{
tree[cur] = a[rk[l]];
return ;
}
int mid = (l + r) >> 1;
build (cur << 1,l,mid);
build (cur << 1 | 1,mid + 1,r);
tree[cur] = tree[cur << 1] + tree[cur << 1 | 1];
}
void pushdown (int cur,int l,int r)
{
if (!tmp[cur]) return ;
int mid = (l + r) >> 1;
tree[cur << 1] += (mid - l + 1) * tmp[cur];
tree[cur << 1 | 1] += (r - mid) * tmp[cur];
tmp[cur << 1] += tmp[cur];tmp[cur << 1 | 1] += tmp[cur];
tmp[cur] = 0;
}
void modify (int cur,int l,int r,int x,int y,int v)
{
if (x <= l && y >= r)
{
tree[cur] += (r - l + 1) * v;
tmp[cur] += v;
return ;
}
pushdown (cur,l,r);
int mid = (l + r) >> 1;
if (x <= mid) modify (cur << 1,l,mid,x,y,v);
if (y > mid) modify (cur << 1 | 1,mid + 1,r,x,y,v);
tree[cur] = tree[cur << 1] + tree[cur << 1 | 1];
}
ll query (int cur,int l,int r,int x,int y)
{
if (x <= l && y >= r) return tree[cur];
pushdown (cur,l,r);
int mid = (l + r) >> 1;ll s = 0;
if (x <= mid) s += query (cur << 1,l,mid,x,y);
if (y > mid) s += query (cur << 1 | 1,mid + 1,r,x,y);
return s % p;
}
void dfs1 (int u,int fa)
{
sz[u] = 1;
dep[u] = dep[fa] + 1;
f[u] = fa;
for (int i = head[u];i;i = nxt[i])
{
int v = to[i];
if (v == fa) continue;
dfs1 (v,u);
sz[u] += sz[v];
if (sz[hson[u]] < sz[v]) hson[u] = v;//子块大的便是重儿子
}
}
void dfs2 (int u,int fa)
{
if (hson[u])
{
top[hson[u]] = top[u];
dfn[hson[u]] = ++cnt;
rk[cnt] = hson[u];
dfs2 (hson[u],u);//优先对重儿子进行剖分,从而保证重链上的 dfn 序连续
}
for (int i = head[u];i;i = nxt[i])
{
int v = to[i];
if (top[v]) continue;
top[v] = v;//单独以 v 作为重链的顶端
dfn[v] = ++cnt;
rk[cnt] = v;
dfs2 (v,v);
}
}
void upd (int x,int y,int v)// x - y 的最短路径上,所有点的权值均加上 v
{
int fx = top[x],fy = top[y];
while (fx != fy)
{
if (dep[fx] < dep[fy]) swap (fx,fy),swap (x,y);//选择深度大的向上跳
modify (1,1,cnt,dfn[fx],dfn[x],v);//某一段链的更新 注意 f[x] 的 dfn 序更小
x = f[fx],fx = top[x];
}
if (dep[x] > dep[y]) swap (x,y);
modify (1,1,cnt,dfn[x],dfn[y],v);//最后 x 与 y 在同一条链上
}
ll ask (int x,int y)
{
int fx = top[x],fy = top[y];
ll ans = 0;
while (fx != fy)
{
if (dep[fx] < dep[fy]) swap (fx,fy),swap (x,y);
ans += query (1,1,cnt,dfn[fx],dfn[x]);
x = f[fx],fx = top[x];
}
if (dep[x] > dep[y]) swap (x,y);
ans += query (1,1,cnt,dfn[x],dfn[y]);
return ans % p;
}
  1. 拓展应用

求最近公共祖先。不断向上跳重链,当跳到同一条重链上时,深度较小的结点即为 $\texttt{LCA}$。

核心代码就是上一模板题的查询操作的修改。

1
2
3
4
5
6
7
8
9
10
11
int LCA (int x,int y)
{
int fx = top[x],fy = top[y];
while (fx != fy)
{
if (dep[fx] < dep[fy]) swap (fx,fy),swap (x,y);
x = f[fx],fx = top[x];
}
if (dep[x] > dep[y]) swap (x,y);
return x;
}

可持久化线段树

P3919 【模板】可持久化线段树 1(可持久化数组)

  1. 作用
  • 可保留每一个历史版本。

  • 实现一些强制在线的功能。

  1. 性质
  • 每一次修改后所增加的节点数最大为 $\log n$。

  • 具有若干个根,且每一个根均可以构成一棵完整的线段树。

  1. 实现过程
  • 只对进行修改的结点进行复制处理。

  • 直接新开一块内存储存新节点,建树等同于新建节点。

  • 对于每一个根,对应着一个版本,因此若要处理某一版本 $i$,访问或更新时直接使用 root[i] 即可。

  • 其它过程与普通线段树差不多。

  • 注意空间开到 N << 5 差不多。

  1. 完整代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 1e6 + 5;
const int MOD = 1e9 + 7;
inline int read ();
int n,m,cnt,a[MAX],root[MAX << 5];
struct node
{
int dl,dr,v;
} tree[MAX << 5];
int make (int cur,int l,int r);
int modify (int cur,int l,int r,int x,int val);
int query (int cur,int l,int r,int x);
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
n = read ();m = read ();
for (int i = 1;i <= n;++i) a[i] = read ();
root[0] = make (0,1,n);//初始的建树过程
for (int i = 1;i <= m;++i)
{
int k = read (),ty = read (),x = read (),y;
if (ty == 1)
{
y = read ();
root[i] = modify (root[k],1,n,x,y);
}
else
{
printf ("%d\n",query (root[k],1,n,x));
root[i] = root[k];//保存当前版本
}
}
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
int make (int cur,int l,int r)
{
cur = ++cnt;
if (l == r)
{
tree[cur].v = a[l];
return cur;
}
int mid = (l + r) >> 1;
tree[cur].dl = make (tree[cur].dl,l,mid);
tree[cur].dr = make (tree[cur].dr,mid + 1,r);
return cur;
}
int modify (int cur,int l,int r,int x,int val)
{
tree[++cnt] = tree[cur];//新开一个节点
cur = cnt;
if (l == r)
{
tree[cur].v = val;
return cur;
}
int mid = (l + r) >> 1;
if (x <= mid) tree[cur].dl = modify (tree[cur].dl,l,mid,x,val);
else tree[cur].dr = modify (tree[cur].dr,mid + 1,r,x,val);
return cur;
}
int query (int cur,int l,int r,int x)
{
if (l == r) return tree[cur].v;
int mid = (l + r) >> 1;
if (x <= mid) return query (tree[cur].dl,l,mid,x);
else return query (tree[cur].dr,mid + 1,r,x);
}

P3834 【模板】可持久化线段树 2P1533 可怜的狗狗

和上题差不多,用 root[i] 表示范围为 $[1,i]$ 的一个版本,查询时用前缀和的思想进行线段树的相减。

即在求第 $k$ 小的数时,先判断是否在左子树内,是则递归;否则递归右子树并将问题变为求第 $k - x$ 小的数($x$ 表示进行相减后的区间的数的个数)。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 3e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
struct T
{
int l,r,sum;
} tree[MAX << 5];
int n,m,len,cnt,a[MAX],num[MAX],root[MAX];
int build (int cur,int l,int r);
int modify (int cur,int l,int r,int x);
int query (int l,int r,int dx,int dy,int k);
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
n = read ();m = read ();
for (int i = 1;i <= n;++i) a[i] = read (),num[i] = a[i];
sort (num + 1,num + 1 + n);
len = unique (num + 1,num + n + 1) - num - 1;
root[0] = build (1,1,len);
for (int i = 1;i <= n;++i)
{
int x = lower_bound (num + 1,num + 1 + len,a[i]) - num;
root[i] = modify (root[i - 1],1,len,x);
}
for (int i = 1;i <= m;++i)
{
int x = read (),y = read (),k = read ();
printf ("%d\n",num[query (1,len,root[x - 1],root[y],k)]);
}
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
int build (int cur,int l,int r)
{
cur = ++cnt;
if (l == r) return cur;
int mid = (l + r) >> 1;
tree[cur].l = build (tree[cur].l,l,mid);
tree[cur].r = build (tree[cur].r,mid + 1,r);
return cur;
}
int modify (int cur,int l,int r,int x)
{
tree[++cnt] = tree[cur];
tree[cnt].sum = tree[cur].sum + 1;
cur = cnt;
if (l == r) return cur;
int mid = (l + r) >> 1;
if (x <= mid) tree[cur].l = modify (tree[cur].l,l,mid,x);
else tree[cur].r = modify (tree[cur].r,mid + 1,r,x);
return cur;
}
int query (int l,int r,int dx,int dy,int k)
{
if (l == r) return l;
int x = tree[tree[dy].l].sum - tree[tree[dx].l].sum;//线段树相减
int mid = (l + r) >> 1;
if (x >= k) return query (l,mid,tree[dx].l,tree[dy].l,k); // 左子树能包含第 k 小
else return query (mid + 1,r,tree[dx].r,tree[dy].r,k - x);
}

李超线段树

P4097 [HEOI2013]Segment

支持区间加入一条平面上的线段,单点询问函数最值,用一次函数的形式储存线段。

有两种操作:

  • 加入一个值域为 $[l,r]$ 的一次函数

  • 给定 $k$ 后求与 $x = k$ 相交时 $y$ 最大的线段编号(相同时按照字典序)

拿新线段在中点处的值与原最优线段在中点处的值作比较,择优选取。主要可以分为以下几种情况:

  • 之前没有线段,则新线段为最优解。

  • 新线段完全优于原线段,直接更新整段。

  • 新线段完全劣于原线段,直接忽略。

  • 新线段与原线段有交点,继续进行分类讨论,两种情况继续对应更新左右子区间的两种情形:

  1. 交点在中点中间及左侧

  2. 交点在中点右侧

在写法上大致与普通线段树相同,但是有几个细节需要注意一下:

  • 对于一条线段,如果直接用点斜式表示,斜率可能不存在(但由题知不会为 $0$),所以可以特判此情况将其标记为 $0$。

  • 输入时需要调整两端点的位置,确保 $x_0 < x_1$ 从而方便计算。

  • 在更新最优子区间时相当于标记下放,在此因减少分类讨论量默认为当前最优子区间的中点小于新加入的线段,若不符合强制交换一下,然后在此基础上更新左区间或右区间。

  • 若不存在任何一条线段与查询直线有交,则输出 $0$;若有多条线段与查询直线的交点纵坐标都是最大的,则输出编号最小的线段。以上为输出的要求,对于不相交的线段,特判一下即可;而对于输出答案的方式,写一个函数进行比较,纵坐标相同时比较编号,否则比较纵坐标即可。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 1e5 + 5;
const int MOD1 = 39989,MOD2 = 1e9 + 1;
inline int read ();
struct line
{
double num;
int id;
};
struct node
{
double k,b;
} a[MAX];
int t,ans,cnt,root[MAX << 2];
line cmp (line x,line y);
double calc (int cur,int x);
void add (int sx,int sy,int fx,int fy);
void modify (int cur,int l,int r,int x,int y,int id);
void pushdown (int cur,int l,int r,int id);
line query (int cur,int l,int r,int x);
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
t = read ();
while (t--)
{
bool ty = read ();
if (!ty)
{
int k = read ();
ans = query (1,1,MOD1,(ans + k - 1) % + MOD1 + 1).id;
printf ("%d\n",ans);
}
else
{
int sx = read (),sy = read (),fx = read (),fy = read ();
sx = (ans + sx - 1) % MOD1 + 1;fx = (ans + fx - 1) % MOD1 + 1;
sy = (ans + sy - 1) % MOD2 + 1;fy = (ans + fy - 1) % MOD2 + 1;
if (sx > fx) swap (sx,fx),swap (sy,fy);
add (sx,sy,fx,fy);
modify (1,1,MOD1,sx,fx,cnt);
}
}
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
line cmp (line x,line y)
{
if (x.num != y.num) return x.num < y.num ? y : x;
else return x.id < y.id ? x : y;
}
double calc (int cur,int x)
{
return a[cur].k * x + a[cur].b;
}
void add (int sx,int sy,int fx,int fy)// 已经保证 sx < fx
{
++cnt;
if (sx == fx) a[cnt] = {0,max (sy,fy)};// 垂直与 x 轴时斜率不存在
else
{
a[cnt].k = (double)(fy - sy) / (fx - sx);
a[cnt].b = sy - a[cnt].k * sx;
}
}
void modify (int cur,int l,int r,int x,int y,int id)
{
if (x <= l && y >= r)
{
pushdown (cur,l,r,id);
return ;
}
int mid = (l + r) >> 1;
if (x <= mid) modify (cur << 1,l,mid,x,y,id);
if (y > mid) modify (cur << 1 | 1,mid + 1,r,x,y,id);
}
void pushdown (int cur,int l,int r,int id)// 更新最优的区间
{
int mid = (l + r) >> 1;
if (calc (root[cur],mid) < calc (id,mid)) swap (root[cur],id);
if (calc (root[cur],l) < calc (id,l)) pushdown (cur << 1,l,mid,id);
if (calc (root[cur],r) < calc (id,r)) pushdown (cur << 1 | 1,mid + 1,r,id);
}
line query (int cur,int l,int r,int x)
{
if (x > r || x < l) return {0,0};
line s = {calc (root[cur],x),root[cur]};
if (l == r) return s;
int mid = (l + r) >> 1;
if (x <= mid) s = cmp (s,query (cur << 1,l,mid,x));
else s = cmp (s,query (cur << 1 | 1,mid + 1,r,x));
return s;
}

再来一道几乎属于是模板题的题目—P4254 [JSOI2008]Blue Mary开公司

还是一样求 $x = k$ 时的最大值,不过还是有一些细节要注意:

  • 虽然这是一个一次函数,但是由题目的数据范围可以看作一条 $[1,500000]$ 的线段。

  • 第一天的收益为 $S$,之后每天增长 $P$,所以函数的斜率与截距分别为 $P$ 与 $S - P$。

  • 询问的是函数的值而非编号,同时下取整至百位。

具体代码和上一个差不多,就不写出了。

莫队

简单来说就是一类把 $[l,r]$ 的答案用 $O(1)$ 的算法扩展到 $[l,r - 1],[l,r + 1],[l - 1,r],[l + 1,r]$ 的数据结构。由于涉及到分块,所以时间复杂度为 $O(n\sqrt{n})$。下面我们来简单证明一下莫队的时间复杂度。

首先考虑 $n,m$ 同阶的情况。令每一块 $l$ 的最大值分别为 $\max_1,\max_2,\cdots,\max_{\lceil \sqrt n \rceil}$。每一个块的第一个询问暴力查找,时间复杂度为 $O(n)$。之后在极端情况下每个 $R$ 均为 $n$,$L$ 为 $\max_i \to \max_{i - 1}$ 或 $\max_{i - 1} \to \max_i$。$R$ 同块内由于已经有序,最多为 $O(n)$,全部为 $O(n\sqrt n)$;$L$ 全部为 $O(\sqrt n (\max_i - \max_{i - 1}))$,通过裂项求和已知为 $O(\sqrt n (\max_ {\lceil n \rceil - 1}))$,最坏情况下也就是 $O(n\sqrt n)$。综上,莫队此时的时间复杂度为 $O(n\sqrt n)$。

对于 $n,m$ 不同阶的情况,设块长度为 $S$,则复杂度为 $O(n \times \dfrac{n}{S} + mS)$,由基本不等式可知,当 $S = \dfrac{n}{\sqrt m}$ 时取到最小值,即 $O(n \sqrt m)$。

莫队的模板如下(需要特别注意四个 while 循环的位置,前两步先扩大区间,即 --l++r,后两步再缩小区间,即 ++l--r,从而保持 $l \le r + 1$,使得区间保持合法):

1
2
3
4
5
6
7
8
9
10
11
12
13
void solve ()
{
sort (a + 1,a + 1 + m);
for (int i = 1;i <= m;++i)
{
node q = a[i];
while (l > q.l) move (--l,1);
while (r < q.r) move (r++,1);
while (l < q.l) move (l++,1);
while (r > q.r) move (--r,1);
ans[q.id] = nw;//The function "move" updates nw
}
}

一道例题 P1494 [国家集训队] 小 Z 的袜子

在排序之后,设 $f_i$ 表示颜色 $i$ 当前出现的次数,则更新答案时,扩大区间后答案为 $\tbinom{f_k + 1}{2} - \tbinom{f_k}{2}$,缩小区间为 $\tbinom{f_k}{2} - \tbinom{f_k - 1}{2}$。令答案为 $k$,则对于一个 $[l,r]$ 的询问,答案即为 $\dfrac{k}{\tbinom{l - r + 1}{2}}$。经过计算可以化简表达式, $\tbinom{f_k + 1}{2} - \tbinom{f_k}{2} = \dfrac{(f_k + 1)f_k - f_k(f_k - 1)}{2} = \dfrac{2f_k}{2} = f_k$。

因此我们通过莫队维护,时间复杂度为 $O(n\sqrt{n})$。完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 5e4 + 5;
const int MOD = 1e9 + 7;
inline int read ();
struct node
{
int l,r,id;
} a[MAX];
int n,m,s,x = 1,y,cnt[MAX],p[MAX];
ll sum,ansx[MAX],ansy[MAX];
bool cmp (node x,node y);
void add (int x);
void del (int x);
ll gcd (ll x,ll y);
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
n = read (),m = read ();s = sqrt (n);
for (int i = 1;i <= n;++i) p[i] = read ();
for (int i = 1;i <= m;++i) a[i].l = read (),a[i].r = read (),a[i].id = i;
sort (a + 1,a + 1 + m,cmp);
for (int i = 1;i <= m;++i)
{
if (a[i].l == a[i].r)
{
ansx[a[i].id] = 0,ansy[a[i].id] = 1;
continue;
}
while (x > a[i].l) add (p[--x]);
while (y < a[i].r) add (p[++y]);
while (x < a[i].l) del (p[x++]);
while (y > a[i].r) del (p[y--]);
ansx[a[i].id] = sum;
ansy[a[i].id] = 1ll * (y - x + 1) * (y - x) >> 1;
}
for (int i = 1;i <= m;++i)
{
if (!ansx[i]) ansy[i] = 1;
else
{
ll g = gcd (ansx[i],ansy[i]);
ansx[i] /= g,ansy[i] /= g;
}
printf ("%lld/%lld\n",ansx[i],ansy[i]);
}
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
bool cmp (node x,node y) {return (x.l / s == y.l / s) ? x.r < y.r : x.l < y.l;}
void add (int x) {sum += cnt[x]++;}
void del (int x) {sum -= --cnt[x];}
ll gcd (ll x,ll y) {return (!y) ? x : gcd (y,x % y);}

关于普通莫队,还有一个小小的优化。通过奇偶化排序,可以优化 $r$ 指针的移动次数,从而加快效率。因此上题的 cmp 函数可以修改如下:

1
2
3
4
5
bool cmp (node x,node y)
{
if (x.l / s != y.l / s) return x.l < y.l;
return (x.l / s) & 1 ? x.r < y.r : x.r > y.r;
}

普通莫队不支持修改,但我们可以加入一个时间的信息来支持修改。时间这一维的加入让我们移动的方向由四种变为六种,同时排序的关键字也随之增加。

按照 $n^\frac{2}{3}$ 的大小分成了 $n^{\frac{1}{3}}$ 块,按照左右时间这一关键字顺序进行排序,之后的移动有两种情况:

  • 左右端点所在块不变,时间向后移动,时间复杂度为 $O(n)$。

  • 左右端点所在块改变,最快情况下,时间移动 $n$ 个单位,时间复杂度为 $O(n)$。

由于左右端点所在块各有 $n^{\frac{1}{3}}$ 种,加上时间的复杂度,总共为 $O(n^{\frac{1}{3} + \frac{1}{3} + 1}) = O(n^{\frac{5}{3}})$。

于是来看一下这一题,P1903 [国家集训队] 数颜色 / 维护队列

加入时间这个维度后,按照 $l,r,t$ 的顺序从小到大排序。现在唯一多的一个步骤就是处理时间这个维度。若当前将第 $p$ 个位置上的颜色 $x$ 改为 $y$,莫队当前区间为 $[l,r]$。有以下两种情况:

  • 加入修改。若该位置在区间 $[l,r]$ 内,则将其删除后改为 $y$,否则直接加上新的颜色 $y$。

  • 还原修改。转换一下,相当于将 $p$ 个位置由颜色 $y$ 改为 $x$。

写的时候两个操作用了两个结构体区别开。$Q$ 是存放询问的内容,里面的元素分别为询问编号,左端点,右端点,时间;$M$ 是修改,里面的元素分别为位置,旧颜色,新颜色。最后还要注意一下空间大小,具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 2e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
struct Q {int id,l,r,t;} q[MAX];
struct M {int p,pre,nw;} a[MAX];
int n,m,s,tot_q,tot_m,l = 1,r,Time,sum;
int col[MAX],ans[MAX],pos[MAX],cnt[MAX * 10];
bool cmp (Q x,Q y);
void solve (int p,int x);
void add (int x);
void del (int x);
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
n = read ();m = read ();s = (int) pow (n,2 / 3.0);
for (int i = 1;i <= n;++i) col[i] = read (),pos[i] = col[i];
for (int i = 1;i <= m;++i)
{
char ty;scanf ("%c",&ty);
int x = read (),y = read ();
if (ty == 'R') a[++tot_m] = (M){x,pos[x],y},pos[x] = y;
else q[++tot_q] = (Q){tot_q,x,y,tot_m};
}
sort (q + 1,q + 1 + tot_q,cmp);
for (int i = 1;i <= tot_q;++i)
{
while (Time < q[i].t) solve (a[Time + 1].p,a[Time + 1].nw),++Time;
while (Time > q[i].t) solve (a[Time].p,a[Time].pre),--Time;
while (l > q[i].l) add (col[--l]);
while (r < q[i].r) add (col[++r]);
while (l < q[i].l) del (col[l++]);
while (r > q[i].r) del (col[r--]);
ans[q[i].id] = sum;
}
for (int i = 1;i <= tot_q;++i) printf ("%d\n",ans[i]);
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
bool cmp (Q x,Q y)
{
if (x.l / s != y.l / s) return x.l < y.l;
if (x.r / s != y.r / s) return x.r < y.r;
return x.t < y.t;
}
void solve (int p,int x)
{
if (l <= p && p <= r) add (x),del (col[p]);
col[p] = x;
}
void add (int x)
{
++cnt[x];
if (cnt[x] == 1) ++sum;
}
void del (int x)
{
if (cnt[x] == 1) --sum;
--cnt[x];
}
  • 回滚莫队

在区间转移时,可能会出现删减操作无法实现的问题,而这时可以考虑是使用回滚莫队解决问题。

以这道例题 AT1219 JOISC 2014 Day1 历史研究 为例。该题的删除操作用普通莫队难以解决,由于询问离线,故可以考虑回滚莫队。

设块大小为 $s$,对于第 $i$ 的块有左右端点 $l_i,r_i$,莫队当前区间为 $[L,R]$。先对询问按之前的方式排序,然后根据不同情况分类讨论:

  • 当前左端点对应的块的编号为 $x$,若与上一个询问的左端点所处编号为 $y$ 的块的左端点不同,那么有 $L = r_x + 1,R = r_x$。

  • 询问的左右端点处于同一个块,则直接暴力循环求解。

  • 询问的左右端点处于不同块,则有:

    1. 询问右端点大于 $R$:扩展莫队区间右端点。
    2. 同理扩展莫队区间左端点。
    3. 回答询问后撤销扩展莫队区间左端点,并将其回滚至 $r_x + 1$。
      现在来分析一下时间复杂度,块大小仍然是 $s$。还是分类讨论,对于一个询问:
  • 若左右端点在同一个块内,则普通的块内查询即可,显然遍历的时间复杂度为 $O(s)$。

  • 若不在同一个块内,基于排序的方式,固定左端点,同时右端点单调。那么移动右端点所需的时间复杂度为 $O(n)$,左端点最多移动 $s$ 次。

所以总复杂度为 $O(s \times m + n \times
\dfrac{n}{s})$,有基本不等式易知当 $s = \dfrac{n}{\sqrt{m}}$ 时最优,为 $O(n \sqrt{m})$。

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 1e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
struct query
{
int l,r,id;
} qu[MAX];
int n,q,tot,times,s;
int a[MAX],b[MAX],cnt[MAX],dcnt[MAX],L[MAX],R[MAX],pos[MAX];
ll ans[MAX],sum;
bool cmp (query xx,query yy);
void add (int x,ll &v);
void del (int x);
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
n = read ();q = read ();
for (int i = 1;i <= n;++i) a[i] = read (),b[++times] = a[i];
for (int i = 1;i <= q;++i) qu[i].l = read (),qu[i].r = read (),qu[i].id = i;
s = n / sqrt (q);tot = ceil ((double)n / s);
sort (b + 1,b + 1 + times);
times = unique (b + 1,b + 1 + times) - b - 1;
for (int i = 1;i <= n;++i) a[i] = lower_bound (b + 1,b + 1 + times,a[i]) - b;//离散化
for (int i = 1;i <= tot;++i)//初始化
{
L[i] = (i - 1) * s + 1;
R[i] = min (n,i * s);
}
for (int i = 1;i <= tot;++i)
for (int j = L[i];j <= R[i];++j) pos[j] = i;
sort (qu + 1,qu + 1 + q,cmp);
int l = 1,r = 0,la;
for (int i = 1;i <= q;++i)
{
if (pos[qu[i].l] == pos[qu[i].r])// 左右端点在同一块内
{
for (int j = qu[i].l;j <= qu[i].r;++j) ++dcnt[a[j]];
for (int j = qu[i].l;j <= qu[i].r;++j) ans[qu[i].id] = max (ans[qu[i].id],1ll * b[a[j]] * dcnt[a[j]]);
for (int j = qu[i].l;j <= qu[i].r;++j) --dcnt[a[j]];
continue;
}
if (pos[qu[i].l] != la) //一个新块 初始化莫队区间
{
while (r > R[pos[qu[i].l]]) del (a[r--]);
while (l < R[pos[qu[i].l]] + 1) del (a[l++]);
l = R[pos[qu[i].l]] + 1;
la = pos[qu[i].l];sum = 0;
}
while (r < qu[i].r) add (a[++r],sum);
int nwl = l;ll nwsum = sum;//记录当前 l 从而进行回滚
while (l > qu[i].l) add (a[--l],nwsum);
ans[qu[i].id] = nwsum;
while (l < nwl) del (a[l++]);//回滚
l = nwl;
}
for (int i = 1;i <= q;++i) printf ("%lld\n",ans[i]);
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
bool cmp (query xx,query yy) {return (pos[xx.l] != pos[yy.l]) ? pos[xx.l] < pos[yy.l] : xx.r < yy.r;}
void add (int x,ll &v) {++cnt[x];v = max (v,1ll * cnt[x] * b[x]);}
void del (int x) {--cnt[x];}

再来一道P5906 【模板】回滚莫队&不删除莫队。简要题意:求一段区间 $[l,r]$ 种相同数中下标的极值,多次询问并且离线。

还是一样,用普通的莫队易于扩展区间而不利于缩小区间,仍然考虑回滚莫队。记得记录最大最小值下标的也要回滚,同时变量不要冲突,多次覆盖。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#define init(x) memset (x,0,sizeof (x))
#define ll long long
#define ull unsigned long long
#define INF 0x3f3f3f3f
using namespace std;
const int MAX = 2e5 + 5;
const int MOD = 1e9 + 7;
inline int read ();
struct Q
{
int l,r,id;
} q[MAX];
int n,m,times,s,tot,l = 1,r,la,sum;
int a[MAX],b[MAX],L[MAX],R[MAX];
int vis[MAX],pos[MAX],ans[MAX];
int mx[MAX],mn[MAX],dmx[MAX],dmn[MAX],nmx[MAX],nmn[MAX];
bool cmp (Q xx,Q yy);
void add (int x,int &v,bool ty);
void del (int x,bool ty);
int main ()
{
n = read ();
for (int i = 1;i <= n;++i) a[i] = read (),b[++times] = a[i];
m = read ();
for (int i = 1;i <= m;++i) q[i].l = read (),q[i].r = read (),q[i].id = i;
sort (b + 1,b + 1 + times);
times = unique (b + 1,b + 1 + times) - b - 1;
for (int i = 1;i <= n;++i) a[i] = lower_bound (b + 1,b + 1 + times,a[i]) - b;
s = n / sqrt (m);tot = ceil ((double) n / s);
for (int i = 1;i <= tot;++i)
{
L[i] = s * (i - 1) + 1;
R[i] = min (n,s * i);
}
for (int i = 1;i <= tot;++i)
for (int j = L[i];j <= R[i];++j) pos[j] = i;
sort (q + 1,q + 1 + m,cmp);
for (int i = 1;i <= times;++i) mn[i] = dmn[i] = INF;
for (int i = 1;i <= m;++i)
{
if (pos[q[i].l] == pos[q[i].r])
{
for (int j = q[i].l;j <= q[i].r;++j) nmx[a[j]] = 0,nmn[a[j]] = INF;
for (int j = q[i].l;j <= q[i].r;++j) nmx[a[j]] = max (nmx[a[j]],j),nmn[a[j]] = min (nmn[a[j]],j);
for (int j = q[i].l;j <= q[i].r;++j) ans[q[i].id] = max (ans[q[i].id],nmx[a[j]] - nmn[a[j]]);
continue;
}
if (pos[q[i].l] != la)
{
while (r > R[pos[q[i].l]]) del (r--,0);
while (l < R[pos[q[i].l]] + 1) del (l++,0);
r = R[pos[q[i].l]];l = R[pos[q[i].l]] + 1;
la = pos[q[i].l];sum = 0;
}
while (r < q[i].r) add (++r,sum,0);
int nwsum = sum,nwl = l;
while (nwl > q[i].l) add (--nwl,nwsum,1);
while (nwl < l) del (nwl++,1);
ans[q[i].id] = nwsum;
}
for (int i = 1;i <= m;++i) printf ("%d\n",ans[i]);
return 0;
}
inline int read ()
{
int s = 0;int f = 1;
char ch = getchar ();
while ((ch < '0' || ch > '9') && ch != EOF)
{
if (ch == '-') f = -1;
ch = getchar ();
}
while (ch >= '0' && ch <= '9')
{
s = s * 10 + ch - '0';
ch = getchar ();
}
return s * f;
}
bool cmp (Q xx,Q yy) {return (pos[xx.l] != pos[yy.l]) ? xx.l < yy.l : xx.r < yy.r;}
void add (int x,int &v,bool ty)
{
if (!ty)
{
mx[a[x]] = max (mx[a[x]],x);mn[a[x]] = min (mn[a[x]],x);
v = max (v,mx[a[x]] - mn[a[x]]);
}
else
{
dmn[a[x]] = min (dmn[a[x]],min (mn[a[x]],x));dmx[a[x]] = max (dmx[a[x]],max (mx[a[x]],x));
v = max (v,dmx[a[x]] - dmn[a[x]]);
}
}
void del (int x,bool ty)
{
if (!ty) mx[a[x]] = 0,mn[a[x]] = INF;
else dmx[a[x]] = 0,dmn[a[x]] = INF;
}