「LightOJ-1428」Melody Comparison(后缀数组)

「LightOJ-1428」Melody Comparison
给定串A和串B,求串A本质不同且不包含串B的子串个数。

题解

用KMP预处理出串A中从每个位置i出发向右延伸、且不包含串B的最大长度rmax[i]。那么对于后缀sa[i],它的不包含串B的前缀个数为rmax[sa[i]]个。因为要求本质不同的子串个数,需要减去当前后缀和上一个后缀相同的前缀个数height[i];但被减去的部分不能超过当前后缀本身的贡献,所以实际减去的是min(height[i], rmax[sa[i]])。答案即为$ans=\sum_{i=1}^{n}\left(rmax[sa[i]]-\min(height[i],\,rmax[sa[i]])\right)$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <bits/stdc++.h>

using namespace std;

// Capacity of every global buffer below.
const int maxn = 5e4 + 10;

// s: text whose suffixes are sorted; b: pattern buffer used by the KMP routines.
char s[maxn], b[maxn];
// sa: resulting suffix array; t / t2: rank scratch buffers; c: counting-sort buckets;
// n: length of the text (set by the caller).
int sa[maxn], t[maxn], t2[maxn], c[maxn], n;

/**
 * Builds the suffix array of s[0 .. n-1] into sa[0 .. n-1] by prefix doubling,
 * sorting (rank, rank+k) pairs with a two-pass counting sort.
 *
 * @param n  number of characters to sort (shadows the global n). main() in this
 *           file passes strlen(s) + 1 so that the terminating '\0' takes part as
 *           the smallest character; the rank comparison below relies on two
 *           suffixes with equal first keys both having a valid position + k.
 * @param m  exclusive upper bound on the initial keys: every s[i] is used
 *           directly as an index into c[0 .. m-1].
 */
void build_sa(int n, int m)
{
    int *x = t, *y = t2;

    // Initial pass: counting sort by single character.
    for (int i = 0; i < m; i++) c[i] = 0;
    for (int i = 0; i < n; i++) c[x[i] = s[i]]++;
    for (int i = 1; i < m; i++) c[i] += c[i - 1];
    for (int i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;

    for (int k = 1; k <= n; k <<= 1)
    {
        // y lists suffix starts ordered by their second key (rank of position + k).
        // Suffixes shorter than k have an empty second key and therefore come first.
        int p = 0;
        for (int i = n - k; i < n; i++) y[p++] = i;
        for (int i = 0; i < n; i++) if (sa[i] >= k) y[p++] = sa[i] - k;

        // Stable counting sort of that order by the first key.
        for (int i = 0; i < m; i++) c[i] = 0;
        for (int i = 0; i < n; i++) c[x[y[i]]]++;
        // Start at 1: beginning at 0 would read c[-1], which is out of bounds.
        for (int i = 1; i < m; i++) c[i] += c[i - 1];
        for (int i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];

        // Re-rank: y now holds the old ranks, x receives the new ones.
        std::swap(x, y);
        p = 1;
        x[sa[0]] = 0;
        for (int i = 1; i < n; i++)
            x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;

        if (p >= n) break;  // all ranks distinct: the order is final
        m = p;              // the next round only needs p buckets
    }
}

// rk: rank table (inverse of sa); height[i]: common-prefix length of the suffixes
// ranked i and i - 1.
int rk[maxn], height[maxn];

/**
 * Fills height[1 .. n] from the suffix array in sa[0 .. n], then converts the
 * stored indices to 1-based form.
 *
 * Reads the globals s, sa and n. Ranks 1 .. n are treated as the real suffixes;
 * sa[0] is only consulted as the predecessor of rank 1. On return every
 * sa[1 .. n] has been incremented by one and rk[1 .. n] holds the previous
 * rk[0 .. n-1], so both tables are addressed by 1-based text positions.
 */
void getHeight()
{
    // Invert the suffix array for ranks 1 .. n.
    int rankPos = 1;
    while (rankPos <= n)
    {
        rk[sa[rankPos]] = rankPos;
        ++rankPos;
    }

    // Walk the text in position order; the match length carried over from the
    // previous position drops by at most one, so it is reused rather than reset.
    int lcp = 0;
    for (int start = 0; start < n; ++start)
    {
        if (lcp != 0) --lcp;
        const int slot = rk[start];
        const int neighbour = sa[slot - 1];
        while (s[start + lcp] == s[neighbour + lcp]) ++lcp;
        height[slot] = lcp;
    }

    // Shift to 1-based positions; iterate downwards so rk[idx - 1] is still the old value.
    for (int idx = n; idx >= 1; --idx)
    {
        sa[idx] += 1;
        rk[idx] = rk[idx - 1];
    }
}

// rmax: per-position result written by KMP(); nex: failure table of the pattern b.
int rmax[maxn], nex[maxn];

/**
 * Builds the KMP failure table of the global pattern b into nex[0 .. strlen(b)].
 *
 * nex[0] is -1, and nex[i] for i >= 1 is the index the matcher falls back to
 * after a mismatch at pattern index i.
 */
void getNext()
{
    const int len = strlen(b);
    nex[0] = -1;
    for (int i = 0, j = -1; i < len; )
    {
        if (j != -1 && b[i] != b[j])
        {
            // Mismatch: retry with the next shorter candidate.
            j = nex[j];
            continue;
        }
        // Either no candidate is left (j == -1) or the characters match: extend.
        ++i;
        ++j;
        nex[i] = j;
    }
}

void KMP()
{
int n = strlen(s), m = strlen(b), i = 0, j = 0, pos = 1;
getNext();
while(i < n)
{
if(j == -1 || s[i] == b[j]) i ++, j ++;
else j = nex[j];
if(j == m)
{
for(int p = pos; p <= i - m + 1; p ++) rmax[p] = i - p;
j = nex[j], pos = i - m + 2;
}
}
while(pos <= n) rmax[pos] = n - pos + 1, pos ++;
}

/**
 * Reads T test cases, each consisting of a text and a pattern, and prints for
 * every case the sum over ranks i of rmax[sa[i]] - min(rmax[sa[i]], height[i]).
 *
 * Stops quietly when the input ends early or is malformed, instead of running
 * on an uninitialised case count or on buffers left over from the previous case.
 */
int main()
{
    int T = 0;
    if (scanf("%d", &T) != 1) return 0;

    for (int _ = 1; _ <= T; _++)
    {
        if (scanf("%s%s", s, b) != 2) break;
        n = strlen(s);

        KMP();
        // n + 1 so that the terminating '\0' is sorted as well and occupies sa[0].
        build_sa(n + 1, 130);
        getHeight();

        long long res = 0;
        for (int i = 1; i <= n; i++)
        {
            // Prefixes already shared with the previous suffix in sorted order are
            // discounted, but never more than this suffix contributes itself.
            res += rmax[sa[i]] - std::min(rmax[sa[i]], height[i]);
        }
        printf("Case %d: %lld\n", _, res);
    }
    return 0;
}