「2018 ACM-ICPC Nanjing - Onsite」M - Mediocre String Problem (字符串)

M-Mediocre String Problem
给定字符串S,T,求S的子串与T的前缀子串能够组成的回文串个数

题意

给定两个字符串$s ,t$,取$s$的子串$s’$和$t$的前缀子串$t’$,并使$|s’|>|t’|$.拼接$s’,t’$得到$str=s’+t’$,求能使$str$为回文串的总方案数。

思路

由于$|s’|>|t’|$,可令$s’=a+b,t’=c ,(|a|=|c|>0,|b|>0)$

因此$str=a+b+c$,由回文串性质可知,$b$为长度大于0的回文串,且$reverse(a)=c$

如,对于字符串$s=aabbcdedc,t=bbaa$,以$x=4$为例

$aabb|cdedc$

$aabb$

$;;abb$

$;;;;bb$

$;;;;;b$

$a,c$有以上4种取法,$b=c或b=cdedc$,共有2×4=8种情况

解法

对于$1≤i≤|s|$求出以$s$以第$i$位开头的回文串个数$CNT(i)$,可以采用Manacher,利用回文串性质差分求解;

翻转$s$,利用ex-KMP求解$reverse(s)$的后缀与$t$的最长公共前缀$LCP$;

对于原串$s$的第$x$位,能够组成的回文串个数为$LCP(x)·CNT(x+1)$,求和即为所求解.

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int maxn=1e6+10;

char s[maxn],t[maxn];

char tmp[maxn<<1];
int Len[maxn<<1],cnt[maxn];

int init(char *str)
{
int i,len=strlen(str);
tmp[0]='@';
for(int i=1;i<=2*len;i+=2)
{
tmp[i]='#';
tmp[i+1]=str[i/2];
}
tmp[2*len+1]='#';
tmp[2*len+2]='$';
tmp[2*len+3]=0;
return 2*len+1;
}

int manacher(char *str)
{
int mx=0,ans=0,pos=0;
int len=init(str);
for(int i=1;i<=len;i++)
{
if(mx>i) Len[i]=min(mx-i,Len[2*pos-i]);
else Len[i]=1;
while(tmp[i-Len[i]]==tmp[i+Len[i]]) Len[i]++;
if(Len[i]+i>mx) mx=Len[i]+i,pos=i;
}
for(int i=2;i<len;i++)
{
if(tmp[i]=='#'&&Len[i]==1) continue;
int x=i/2-Len[i]/2,y=(Len[i]-1)/2;
if((Len[i]-1)%2==0) y--;
cnt[x]++;
cnt[x+y+1]--;
}
}

int extend[maxn],nex[maxn];

void getNext(char *s)
{
int len=strlen(s);
nex[0]=len;
int pos=0;
while(pos+1<len&&s[pos]==s[pos+1]) pos++;
nex[1]=pos;
int k=1,L;
for(int i=2;i<len;i++)
{
pos=k+nex[k]-1;
L=nex[i-k];
if(i+L<=pos) nex[i]=L;
else {
int j=pos-i+1;
if(j<0) j=0;
while(i+j<len&&s[i+j]==s[j]) j++;
nex[i]=j;
k=i;
}
}
}

void getExtend(char *s,char *t)
{
int lens=strlen(s),lent=strlen(t);
getNext(t);
int pos=0;
while(pos<lens&&pos<lent&&s[pos]==t[pos]) pos++;
extend[0]=pos;
int k=0,L;
for(int i=1;i<lens;i++)
{
pos=k+extend[k]-1;
L=nex[i-k];
if(i+L<=pos) extend[i]=L;
else {
int j=pos-i+1;
if(j<0) j=0;
while(i+j<lens&&j<lent&&s[i+j]==t[j]) j++;
extend[i]=j;
k=i;
}
}
}

int main()
{
scanf("%s%s",s,t);
memset(cnt,0,sizeof cnt);
int lens=strlen(s),lent=strlen(t);
manacher(s);
for(int i=0;i<lens;i++) cnt[i]+=cnt[i-1];
reverse(s,s+lens);
getExtend(s,t);
long long ans=0;
for(int i=1;i<lens;i++)
ans+=1ll*cnt[lens-i]*extend[i];
printf("%lld\n",ans);
return 0;
}