2014-06-11
■ AOJ 2292
文字列強化週間なのでといてみた。APIO 2014の1問目の上位互換。
解法: manacherのほげで適当に各場所を中心とする回文長を求める。
次に長さlenの回文について考える。
場所i,jに対しiを中心とする長さlenの文字列とjを中心とする長さlenの文字列が等しい回文であるときuf木上でつなぐ。
長さを大きい順に処理すればつながれたものを切る必要がない。
uf木上の全グループに対して、Sに属してる文字列の数*Tに属してる文字列の数の総和が答えになる。
これは更新するたびに計算すれば特に処理する必要はない。
これでO(N log N)で解ける。
#include#include #include #include using namespace std; typedef long long ll; int par[100005],ran2[100005],a[100005],b[100005]; int sa[100005],lcp[100005],ran[100005],rad[200005],tmp[100005]; int n,k; string S="",s1,s2; ll res,cur; void init() { for(int i=0;i<100005;i++) par[i] = i; for(int i=0;i<100005;i++) ran2[i] = a[i] = b[i] = 0; } int find(int x) { if(x == par[x]) return x; else return par[x] = find(par[x]); } void unite(int x,int y) { x = find(x); y = find(y); if(x == y) return; if(ran2[x] < ran2[y]) { par[x] = y; cur += 1LL*a[x]*b[y]+1LL*a[y]*b[x]; a[y] += a[x]; b[y] += b[x]; } else { par[y] = x; cur += 1LL*a[x]*b[y]+1LL*a[y]*b[x]; a[x] += a[y]; b[x] += b[y]; if(ran2[x] == ran2[y]) ran2[x]++; } } bool same(int x,int y) { return find(x) == find(y); } void add_a(int x) { x = find(x); cur += b[x]; a[x]++; } void add_b(int x) { x = find(x); cur += a[x]; b[x]++; } void manacher() { string str(2*n+1,'#'); for(int i=0;i 2+1] = S[i]; int i = 0, j = 0; for(;i<2*n+1;) { while(i-j >= 0 && i+j < 2*n+1 && str[i-j] == str[i+j]) j++; rad[i] = j; int k = 1; while(i-k >= 0 && rad[i]-k > rad[i-k]) { rad[i+k] = rad[i-k]; ++k; } i += k; j = max(j-k,0); } } bool compare_sa(int i,int j) { if(ran[i] != ran[j]) return ran[i] < ran[j]; else { int ri = i+k <= n? ran[i+k]: -1; int rj = j+k <= n? ran[j+k]: -1; return ri < rj; } } void construct_sa() { for(int i=0;i<=n;i++) { sa[i] = i; ran[i] = i 1; } for(k=1;k<=n;k*=2) { sort(sa,sa+n+1,compare_sa); tmp[sa[0]] = 0; for(int i=1;i<=n;i++) { tmp[sa[i]] = tmp[sa[i-1]] + compare_sa(sa[i-1],sa[i]); } for(int i=0;i<=n;i++) { ran[i] = tmp[i]; } } } void construct_lcp() { int h = 0; lcp[sa[0]] = 0; for(int i=0;i int j = sa[ran[i]-1]; if(h) h--; while(i+h < n && j+h < n && S[i+h] == S[j+h]) h++; lcp[ran[i]-1] = h; } } vector<int>query[100005]; vector<int>in[100005]; int main() { cin >> s1 >> s2; S = s1+"$"+s2; n = S.size(); construct_sa(); construct_lcp(); manacher(); //odd cur = 0; init(); for(int i=0;i for(int i=0;i int f = rad[i*2+1]-1; in[f].push_back(i); } int m = n; if(m%2==0) { m--; } for(int i=n;i>(m+1)/2;i--) { for(int j=0;j 1]); } } for(int i=m;i>=1;i-=2) { for(int j=0;j if(in[i][j] < s1.size()) add_a(in[i][j]); else if(in[i][j] > s1.size()) add_b(in[i][j]); } for(int j=0;j 1)/2].size();j++) { unite(sa[query[(i+1)/2][j]],sa[query[(i+1)/2][j]+1]); } res += cur; } //even cur = 0; init(); for(int i=0;i<100005;i++) in[i].clear(); for(int i=0;i<=2*n;i+=2) { int f = rad[i]-1; in[f].push_back(i/2); } m = n; if(m%2==1) { m--; } for(int i=n;i>m/2;i--) { for(int j=0;j 1]); } } for(int i=m;i>=2;i-=2) { for(int j=0;j if(in[i][j] < s1.size()) add_a(in[i][j]); else if(in[i][j] > s1.size()) add_b(in[i][j]); } for(int j=0;j 2].size();j++) { unite(sa[query[i/2][j]],sa[query[i/2][j]+1]); } res += cur; } cout << res << endl; }