포스트

BOJ 33953 PPC와 CPP 2

sorohue가 PS하는 블로그

BOJ 33953 PPC와 CPP 2

문제 링크입니다.

문제 요약

PC로만 이루어진 길이 $(K+1)N$의 문자열 중, $K$개의 연속한 P와 인접한 C를 삭제하는 연산을 반복해 빈 문자열로 만들 수 있는 문자열의 수를 구하세요.

33952. Cont

풀이의 흐름이 33952번 문제로부터 바로 이어집니다. 해당 문제를 풀지 않았다면 먼저 푸는 것을 추천드립니다. 해당 문제의 풀이는 여기에서 확인할 수 있습니다.

P의 개수가 C의 $K$배인 (부분)문자열만 고려합시다.

이전 문제의 풀이로부터 확장하면 어떤 문자열을 지울 수 있음과 문자열을 한쪽 끝이 C인 부분문자열들로 분해할 수 있음이 동치임을 알 수 있습니다.

문자열을 한쪽 끝이 C인 경우와 그렇지 않은 경우로 나누어 생각해 봅시다.

뭉탱이로 지우기

전체 문자열이 조건을 만족하는 경우 = 한쪽 끝이 C인 경우의 수를 구합시다. 양쪽 끝이 모두 C인 경우와 한쪽 끝만 C인 경우를 나누어 고려하면, 양쪽 끝이 고정된 상태에서 나머지 $(K+1)N-2$자리 안에 남은 C를 끼워넣는 방식으로 카운팅할 수 있으므로 그 값은 ${(K+1)N-2 \choose N-2} + 2{(K+1)N-2 \choose N-1}$임을 알 수 있습니다.

갈라서 지우기

문자열의 양 끝이 P인 지울 수 있는 문자열을 생각합시다. 양쪽 끝이 지워질 수 있으려면 각각 P..CC..P로 묶어야 하고, 그 사이에 적당한 문자열이 끼어들어가 있는 형태가 됩니다.

끼어들어가는 문자열을 C..C 꼴로 통일합시다. 포함-배제의 원리에 의해 전체 문자열의 개수는 끼워넣은 C..C가 0개일 때 - 1개일 때 + 2개일 때 - 3개일 때 … 로 계산할 수 있습니다.

문자열의 전체 길이가 정해져 있으므로 P..C, C..C, C..P 각각의 길이 합이 $(K+1)N$이 되도록 각각의 항을 곱할 수 있어야 합니다. 이는 각 문자열 쪼가리의 생성 함수를 FFT해 계산할 수 있습니다. 생성 함수의 $i$차항은 길이가 $(K+1)i$인 문자열의 개수를 나타내도록 하고 합성곱한 다항식의 $N$차항 계수를 꺼내면 충분합니다.

계산을 위해 C..C의 생성 함수 $F$에 대해 $1 - F + F^2 - F^3 + \cdots + (-F)^N$ 를 구해야 합니다. 이를 지수의 이진 표현을 바탕으로 분할 정복을 통해 계산할 수 있습니다. 구체적으로 $1-F$를 구한 뒤 $F$를 제곱해 $F^2$를 만들어 $(1-F)(1+F^2) = 1-F+F^2-F^3$을 만드는 식입니다.

이를 구현하면 총 시간 복잡도 ${\cal O}(N \lg^2 N)$에 문제를 해결할 수 있습니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
const ll mod = 998244353;
const ll w = 3;

ll pw(ll n, ll r){
	ll ret = 1;
	while(r){
		if(r&1) ret = ret*n%mod;
		n = n*n%mod;
		r >>= 1;
	}
	return ret;
}

inline ll inv(ll n){return pw(n, mod-2);}

void ntt(vector<ll>& f, bool flag = 0){
	int n = f.size(), j = 0;
	vector<ll> root(n>>1);
	for(int i = 1; i < n; i++){
		int bit = n>>1;
		while(j >= bit){
			j -= bit; bit >>= 1;
		}
		j += bit;
		if(i < j) swap(f[i], f[j]);
	}
	ll ang = pw(w, (mod - 1) / n); if(flag) ang = inv(ang);
	root[0] = 1; for(int i=1; i<(n >> 1); i++) root[i] = root[i-1] * ang % mod;
	for(int i=2; i<=n; i<<=1){
		int step = n / i;
		for(int j=0; j<n; j+=i){
			for(int k=0; k<(i >> 1); k++){
				ll u = f[j | k], v = f[j | k | i >> 1] * root[step * k] % mod;
				f[j | k] = (u + v) % mod;
				f[j | k | i >> 1] = (u - v) % mod;
				if(f[j | k | i >> 1] < 0) f[j | k | i >> 1] += mod;
			}
		}
	}
	ll t = inv(n);
	if(flag) for(int i=0; i<n; i++) f[i] = f[i] * t % mod;
}

void mult(vector<ll>& a, vector<ll> b){
	int n = 2; while(n < a.size()+b.size()) n <<= 1;
	a.resize(n); b.resize(n); ntt(a); ntt(b);
	for(int i = 0; i < n; i++) a[i] = a[i]*b[i]%mod;
	ntt(a, 1); return;
}

ll fac[8383838], ifac[8383838];

inline ll C(int n, int r){
	return fac[n]*ifac[r]%mod*ifac[n-r]%mod;
}

ll MungTangE(int n, int k){
	return (C((k+1)*n-2, n-2)+2*C((k+1)*n-2, n-1))%mod;
}

vector<ll> PC, CC, CC_inex;

ll Gala(int n, int k){
	PC = {0}; for(int i = 1; i <= n; i++) PC.push_back(C((k+1)*i-2, i-1));
	mult(PC, PC); PC.resize(n+1);
	CC = {0,0};
	for(int i = 2; i <= n; i++) CC.push_back(C((k+1)*i-2, i-2));
	CC_inex = CC; for(auto& i : CC_inex) i = -i; CC_inex[0]++;
	for(int N = 2; N <= n; N <<= 1){
		mult(CC, CC); CC.resize(n+1); CC[0]++;
		mult(CC_inex, CC); if(CC_inex.size() > n+1) CC_inex.resize(n+1); CC[0]--;
	}
	mult(PC, CC_inex);
	return PC[n];
}

int main(){
	cin.tie(0);cout.tie(0);ios::sync_with_stdio(false);
	fac[0] = 1; for(ll i = 1; i <= 8282828; i++) fac[i] = fac[i-1]*i%mod;
	ifac[8282828] = pw(fac[8282828], mod-2);
	for(ll i = 8282828; i; i--) ifac[i-1] = ifac[i]*i%mod;
	int n, k; cin >> n >> k;
	cout << (MungTangE(n, k)+Gala(n, k))%mod;
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.