将题意进行简单的转换:枚举 $S_k$,然后再枚举其中的断点 $i$($0 \le i \le |S_k|$),将其分为 $S_k[:i]$ 和 $S_k[i:]$ 两段。因此只需要求出以 $S_k[:i]$ 为前缀的其他串的个数与以 $S_k[i:]$ 为后缀的其他串的个数的乘积,不难想到对前缀和后缀分别建立 $\texttt{Trie}$ 树(后缀 $\texttt{Trie}$ 将每个串反转后插入)。

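但是问题并没有完全解决。手动模拟第一个样例可以发现,$\texttt{A + AA}$ 和 $\texttt{AA + A}$ 均可以得到 $\texttt{AAA}$,也就是说目前的算法存在重复计算。考虑一个贪心的思想:对于同一个提供前缀的串,若它能匹配一个较长的前缀,则只在最长的那个前缀上计数,而不在其中较短的前缀上重复计数,显然这样每一种组合恰好被统计一次,不重不漏。那么,结合差分的思想,设 $f_{i,j}$ 表示以 $S_i[:j]$ 为前缀的其他串的个数,则 $f_{i,j} - f_{i,j + 1}$(约定 $f_{i,|S_i| + 1} = 0$)得到的一定是与 $S_i$ 的最长公共前缀长度恰好为 $j$ 的串的个数,后缀同理。在统计答案的时候,后缀一侧从长的串到短的串进行累积,得到后缀匹配长度不小于 $|S_i| - j$ 的串的个数,再与 $f_{i,j} - f_{i,j + 1}$ 相乘并求和即可。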

时间复杂度 $O(\sum |S_i|)$。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <map>
#include <vector>
// Contest-template helpers. init zero-fills a whole array (sizeof of the
// argument), so it only works on real arrays, not on pointers.
#define init(x) memset (x,0,sizeof (x))
typedef long long ll;
typedef unsigned long long ull;
const int INF = 0x3f3f3f3f;

using namespace std;

// Upper bound shared by every array below: string count, one string's
// length (str), and the node pool of each trie.
const int MAX = 2e6 + 5;
// Not referenced in this listing; the answer is printed without reduction.
const int MOD = 1e9 + 7;

inline int read ();

// Nucleotide -> child slot; main registers A/G/C/T as 1..4.
map <char,int> mp;

int n;                  // number of input strings (1-based indices)
int pre_cnt = 1;        // nodes used in the prefix trie; node 1 is the root
int sub_cnt = 1;        // nodes used in the suffix trie; node 1 is the root
int len[MAX];           // len[i] = length of string i
int pre_ch[MAX][5];     // prefix-trie children, indexed by mp slot
int sub_ch[MAX][5];     // suffix-trie children (strings walked back to front)
int pre_tot[MAX];       // strings passing through each prefix-trie node
int sub_tot[MAX];       // strings passing through each suffix-trie node

vector <ll> pre_ans[MAX];   // per string: counts indexed by common-prefix length
vector <ll> sub_ans[MAX];   // per string: counts indexed by common-suffix length
vector <char> s[MAX];       // characters of each string

char str[MAX];          // scanf buffer for one string
ll ans;                 // accumulated answer
int main ()
{
//freopen (".in","r",stdin);
//freopen (".out","w",stdout);
mp['A'] = 1;mp['G'] = 2;mp['C'] = 3;mp['T'] = 4;
n = read ();
for (int i = 1;i <= n;++i)
{
scanf ("%s",str);
len[i] = strlen (str);
int u = 1;
for (int j = 0;j < len[i];++j)//two Tries
{
s[i].push_back (str[j]);
int c = mp[s[i][j]];
if (!pre_ch[u][c]) pre_ch[u][c] = ++pre_cnt;
u = pre_ch[u][c];
++pre_tot[u];
}
u = 1;
for (int j = len[i] - 1;~j;--j)
{
int c = mp[s[i][j]];
if (!sub_ch[u][c]) sub_ch[u][c] = ++sub_cnt;
u = sub_ch[u][c];
++sub_tot[u];
}
}
for (int i = 1;i <= n;++i)
{
int u = 1,sum;pre_ans[i].push_back (n - 1);
for (int j = 0;j < len[i];++j)
{
int c = mp[s[i][j]];
u = pre_ch[u][c];
pre_ans[i].push_back (pre_tot[u] - 1);//the empty situation
if (!u) break;
}
for (int j = 1;j <= len[i];++j) pre_ans[i][j - 1] -= pre_ans[i][j]; // subtraction gives a precise value
u = 1;sub_ans[i].push_back (n - 1);
for (int j = len[i] - 1;~j;--j)
{
int c = mp[s[i][j]];
u = sub_ch[u][c];
sub_ans[i].push_back (sub_tot[u] - 1);
if (!u) break;
}
for (int j = 1;j <= len[i];++j) sub_ans[i][j - 1] -= sub_ans[i][j];
}
for (int i = 1;i <= n;++i)
{
ll sum = 0;
for (int j = 0;j <= len[i];++j)
sum += sub_ans[i][len[i] - j],ans += sum * pre_ans[i][j];
// Equivalently,for j in range from len[i] to 0 is workable,but sum should be sub_pre
}
printf ("%lld\n",ans);
return 0;
}
// Read one decimal int from stdin.
// Every non-digit byte before the number is skipped; a '-' anywhere in that
// skipped run makes the result negative. The first non-digit after the
// digits is consumed as well. Returns 0 if EOF is reached before any digit.
// No overflow check: the value must fit in int.
inline int read ()
{
    int s = 0;int f = 1;
    // getchar() returns int. Keeping it as int keeps EOF (-1) distinct from
    // byte 0xFF. Where plain char is unsigned it also stops "ch != EOF" from
    // being always true, which previously made the skip loop spin forever at
    // end of input.
    int ch = getchar ();
    while ((ch < '0' || ch > '9') && ch != EOF)
    {
        if (ch == '-') f = -1;
        ch = getchar ();
    }
    while (ch >= '0' && ch <= '9')
    {
        s = s * 10 + ch - '0';
        ch = getchar ();
    }
    return s * f;
}