포스트

BOJ 14958 Rock Paper Scissors

sorohue가 PS하는 블로그

BOJ 14958 Rock Paper Scissors

문제 링크입니다.

문제 요약

길이 $N$의 가위-바위-보 배열에 길이 $M (\le N)$의 가위-바위-보 배열을 완전히 들어가도록 겹칩니다. 겹쳐진 손 모양끼리 가위바위보를 진행할 때, 길이 $M$의 배열 쪽이 최대 몇 번 이길 수 있는지 구하세요.

${\cal O}(NM)$ 시간복잡도는 허용되지 않습니다.

풀이

0, 1, 2에서 중복을 허용해 두 원소를 뽑아 합을 구하면 각각 0, 1, 1, 2, 2, 2, 3, 3, 4 가 나옵니다. 세 경우는 모두 2가 나오고, 나머지 경우는 모두 $2 \mod 3$이 아닌 수만 나옴에 주목합시다.

각각의 가위-바위-보 배열을 $3N$차, $3M$차 다항식처럼 생각하고 3칸 단위로 하나의 손 모양을 표현하도록 합니다. 예를 들어서 가위는 $x^2$, 바위는 $x$, 보는 $1$인 식입니다. 두 다항식을 곱했을 때 $M$ 쪽이 이기는 경우가 항상 $3k+2$차항으로 모이게끔 두 배열에서의 매핑을 달리해 줍시다.

이제 같은 배치에서의 승부 결과가 한 항으로 모이게 만들어줘야 합니다. 이는 길이 $M$의 배열을 미리 뒤집은 뒤에 작업을 진행하면 해결할 수 있습니다.

이제 FFT로 다항식 곱셈을 처리하면 문제를 ${\cal O}((N+M)\lg(N+M))$의 시간 복잡도에 해결할 수 있습니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include<bits/stdc++.h>
#define mid (l+r>>1)
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
const ll mod = 998244353; // =(119<<23)+1
const ll w = 3;

ll pw(ll n, ll r){
	ll ret = 1;
	while(r){
		if(r&1) ret = ret*n%mod;
		n = n*n%mod;
		r >>= 1;
	}
	return ret;
}

inline ll inv(ll n){return pw(n, mod-2);}

void ntt(vector<ll>& f, bool flag = 0){
	int n = f.size(), j = 0;
	vector<ll> root(n>>1);
	for(int i = 1; i < n; i++){
		int bit = n>>1;
		while(j >= bit){
			j -= bit; bit >>= 1;
		}
		j += bit;
		if(i < j) swap(f[i], f[j]);
	}
	ll ang = pw(w, (mod - 1) / n); if(flag) ang = inv(ang);
	root[0] = 1; for(int i=1; i<(n >> 1); i++) root[i] = root[i-1] * ang % mod;
	for(int i=2; i<=n; i<<=1){
		int step = n / i;
		for(int j=0; j<n; j+=i){
			for(int k=0; k<(i >> 1); k++){
				ll u = f[j | k], v = f[j | k | i >> 1] * root[step * k] % mod;
				f[j | k] = (u + v) % mod;
				f[j | k | i >> 1] = (u - v) % mod;
				if(f[j | k | i >> 1] < 0) f[j | k | i >> 1] += mod;
			}
		}
	}
	ll t = inv(n);
	if(flag) for(int i=0; i<n; i++) f[i] = f[i] * t % mod;
}

void mult(vector<ll>& a, vector<ll>& b){
	int n = 2; while(n < a.size()+b.size()) n <<= 1;
	a.resize(n); b.resize(n); ntt(a); ntt(b);
	for(int i = 0; i < n; i++) a[i] = a[i]*b[i]%mod;
	ntt(a, 1);
	return;
}

vector<ll> a, b;

int main(){
	cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
	int n, m; cin >> n >> m; a.resize(3*n); b.resize(3*m);
	string s, t; cin >> s >> t; reverse(t.begin(), t.end());
	for(int i = 0; i < n; i++){
		if(s[i] == 'R') a[3*i+2] = 1;
		if(s[i] == 'P') a[3*i+1] = 1;
		if(s[i] == 'S') a[3*i] = 1;
	}
	for(int i = 0; i < m; i++){
		if(t[i] == 'P') b[3*i] = 1;
		if(t[i] == 'S') b[3*i+1] = 1;
		if(t[i] == 'R') b[3*i+2] = 1;
	}
	mult(a, b); ll ans = 0;
	for(int i = 3*(m-1)+2; i < a.size(); i += 3) ans = max(ans, a[i]);
	cout << ans;
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.