HIR180's diary

ICPC World Finals 2022 を集大成に

2014-06-11

AOJ 2292 23:04

文字列強化週間なのでといてみた。APIO 2014の1問目の上位互換。

解法: manacherのほげで適当に各場所を中心とする回文長を求める。

次に長さlenの回文について考える。

lenは奇数とする。(偶数も同様に処理できる。)

場所i,jに対しiを中心とする長さlenの文字列とjを中心とする長さlenの文字列が等しい回文であるときuf木上でつなぐ。

長さを大きい順に処理すればつながれたものを切る必要がない。

uf木上の全グループに対して、Sに属してる文字列の数*Tに属してる文字列の数の総和が答えになる。

これは更新するたびに計算すれば特に処理する必要はない。

これでO(N log N)で解ける。


#include 
#include 
#include 
#include 

using namespace std;

typedef long long ll;

int par[100005],ran2[100005],a[100005],b[100005];
int sa[100005],lcp[100005],ran[100005],rad[200005],tmp[100005];
int n,k;
string S="",s1,s2;
ll res,cur;

void init()
{
	for(int i=0;i<100005;i++) par[i] = i;
	for(int i=0;i<100005;i++) ran2[i] = a[i] = b[i] = 0;
}

int find(int x)
{
	if(x == par[x]) return x;
	else return par[x] = find(par[x]);
}

void unite(int x,int y)
{
	x = find(x);
	y = find(y);
	
	if(x == y) return;
	
	if(ran2[x] < ran2[y])
	{
		par[x] = y;
		cur += 1LL*a[x]*b[y]+1LL*a[y]*b[x];
		a[y] += a[x];
		b[y] += b[x];
	}
	else
	{
		par[y] = x;
		cur += 1LL*a[x]*b[y]+1LL*a[y]*b[x];
		a[x] += a[y];
		b[x] += b[y];
		if(ran2[x] == ran2[y]) ran2[x]++;
	}
}

bool same(int x,int y)
{
	return find(x) == find(y);
}

void add_a(int x)
{
	x = find(x);
	cur += b[x];
	a[x]++;
}

void add_b(int x)
{
	x = find(x);
	cur += a[x];
	b[x]++;
}

void manacher()
{
	string str(2*n+1,'#');
	for(int i=0;i2+1] = S[i];
	
	int i = 0, j = 0;
	
	for(;i<2*n+1;)
	{
		while(i-j >= 0 && i+j < 2*n+1 && str[i-j] == str[i+j]) j++;
		rad[i] = j;
		int k = 1;
		while(i-k >= 0 && rad[i]-k > rad[i-k])
		{
			rad[i+k] = rad[i-k];
			++k;
		}
		i += k;
		j = max(j-k,0);
	}
}

bool compare_sa(int i,int j)
{
	if(ran[i] != ran[j]) return ran[i] < ran[j];
	else
	{
		int ri = i+k <= n? ran[i+k]: -1;
		int rj = j+k <= n? ran[j+k]: -1;
		return ri < rj;
	}
}

void construct_sa()
{
	for(int i=0;i<=n;i++)
	{
		sa[i] = i;
		ran[i] = i1;
	}
	for(k=1;k<=n;k*=2)
	{
		sort(sa,sa+n+1,compare_sa);
		
		tmp[sa[0]] = 0;
		for(int i=1;i<=n;i++)
		{
			tmp[sa[i]] = tmp[sa[i-1]] + compare_sa(sa[i-1],sa[i]);
		}
		for(int i=0;i<=n;i++)
		{
			ran[i] = tmp[i];
		}
	}
}

void construct_lcp()
{
	int h = 0;
	lcp[sa[0]] = 0;
	
	for(int i=0;iint j = sa[ran[i]-1];
		if(h) h--;
		while(i+h < n && j+h < n && S[i+h] == S[j+h]) h++;
		lcp[ran[i]-1] = h;
	}
}
vector<int>query[100005];
vector<int>in[100005];

int main()
{
	cin >> s1 >> s2;
	S = s1+"$"+s2;
	n = S.size();
	construct_sa();
	construct_lcp();
	manacher();
	
	//odd
	cur = 0;
	init();
	for(int i=0;ifor(int i=0;iint f = rad[i*2+1]-1;
		in[f].push_back(i);
	}
	int m = n;
	if(m%2==0)
	{
		m--;
	}
	for(int i=n;i>(m+1)/2;i--)
	{
		for(int j=0;j1]);
		}
	}
	for(int i=m;i>=1;i-=2)
	{
		for(int j=0;jif(in[i][j] < s1.size()) add_a(in[i][j]);
			else if(in[i][j] > s1.size()) add_b(in[i][j]);
		}
		for(int j=0;j1)/2].size();j++)
		{
			unite(sa[query[(i+1)/2][j]],sa[query[(i+1)/2][j]+1]);
		}
		res += cur;
	}
	
	//even
	cur = 0;
	init();
	for(int i=0;i<100005;i++) in[i].clear();
	for(int i=0;i<=2*n;i+=2)
	{
		int f = rad[i]-1;
		in[f].push_back(i/2);
	}
	m = n;
	if(m%2==1)
	{
		m--;
	}
	for(int i=n;i>m/2;i--)
	{
		for(int j=0;j1]);
		}
	}
	for(int i=m;i>=2;i-=2)
	{
		for(int j=0;jif(in[i][j] < s1.size())  add_a(in[i][j]);
			else if(in[i][j] > s1.size()) add_b(in[i][j]);
		}
		for(int j=0;j2].size();j++)
		{
			unite(sa[query[i/2][j]],sa[query[i/2][j]+1]);
		}
		res += cur;
	}
	cout << res << endl;
}