BOJ 14958 Rock Paper Scissors
sorohue가 PS하는 블로그
BOJ 14958 Rock Paper Scissors
문제 링크입니다.
문제 요약
길이 $N$의 가위-바위-보 배열에 길이 $M (\le N)$의 가위-바위-보 배열을 완전히 들어가도록 겹칩니다. 겹쳐진 손 모양끼리 가위바위보를 진행할 때, 길이 $M$의 배열 쪽이 최대 몇 번 이길 수 있는지 구하세요.
${\cal O}(NM)$ 시간복잡도는 허용되지 않습니다.
풀이
0, 1, 2에서 중복을 허용해 두 원소를 뽑아 합을 구하면 각각 0, 1, 1, 2, 2, 2, 3, 3, 4 가 나옵니다. 세 경우는 모두 2가 나오고, 나머지 경우는 모두 $2 \mod 3$이 아닌 수만 나옴에 주목합시다.
각각의 가위-바위-보 배열을 $3N$차, $3M$차 다항식처럼 생각하고 3칸 단위로 하나의 손 모양을 표현하도록 합니다. 예를 들어서 가위는 $x^2$, 바위는 $x$, 보는 $1$인 식입니다. 두 다항식을 곱했을 때 $M$ 쪽이 이기는 경우가 항상 $3k+2$차항으로 모이게끔 두 배열에서의 매핑을 달리해 줍시다.
이제 같은 배치에서의 승부 결과가 한 항으로 모이게 만들어줘야 합니다. 이는 길이 $M$의 배열을 미리 뒤집은 뒤에 작업을 진행하면 해결할 수 있습니다.
이제 FFT로 다항식 곱셈을 처리하면 문제를 ${\cal O}((N+M)\lg(N+M))$의 시간 복잡도에 해결할 수 있습니다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include<bits/stdc++.h>
#define mid (l+r>>1)
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
const ll mod = 998244353; // =(119<<23)+1
const ll w = 3;
ll pw(ll n, ll r){
ll ret = 1;
while(r){
if(r&1) ret = ret*n%mod;
n = n*n%mod;
r >>= 1;
}
return ret;
}
inline ll inv(ll n){return pw(n, mod-2);}
void ntt(vector<ll>& f, bool flag = 0){
int n = f.size(), j = 0;
vector<ll> root(n>>1);
for(int i = 1; i < n; i++){
int bit = n>>1;
while(j >= bit){
j -= bit; bit >>= 1;
}
j += bit;
if(i < j) swap(f[i], f[j]);
}
ll ang = pw(w, (mod - 1) / n); if(flag) ang = inv(ang);
root[0] = 1; for(int i=1; i<(n >> 1); i++) root[i] = root[i-1] * ang % mod;
for(int i=2; i<=n; i<<=1){
int step = n / i;
for(int j=0; j<n; j+=i){
for(int k=0; k<(i >> 1); k++){
ll u = f[j | k], v = f[j | k | i >> 1] * root[step * k] % mod;
f[j | k] = (u + v) % mod;
f[j | k | i >> 1] = (u - v) % mod;
if(f[j | k | i >> 1] < 0) f[j | k | i >> 1] += mod;
}
}
}
ll t = inv(n);
if(flag) for(int i=0; i<n; i++) f[i] = f[i] * t % mod;
}
void mult(vector<ll>& a, vector<ll>& b){
int n = 2; while(n < a.size()+b.size()) n <<= 1;
a.resize(n); b.resize(n); ntt(a); ntt(b);
for(int i = 0; i < n; i++) a[i] = a[i]*b[i]%mod;
ntt(a, 1);
return;
}
vector<ll> a, b;
int main(){
cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
int n, m; cin >> n >> m; a.resize(3*n); b.resize(3*m);
string s, t; cin >> s >> t; reverse(t.begin(), t.end());
for(int i = 0; i < n; i++){
if(s[i] == 'R') a[3*i+2] = 1;
if(s[i] == 'P') a[3*i+1] = 1;
if(s[i] == 'S') a[3*i] = 1;
}
for(int i = 0; i < m; i++){
if(t[i] == 'P') b[3*i] = 1;
if(t[i] == 'S') b[3*i+1] = 1;
if(t[i] == 'R') b[3*i+2] = 1;
}
mult(a, b); ll ans = 0;
for(int i = 3*(m-1)+2; i < a.size(); i += 3) ans = max(ans, a[i]);
cout << ans;
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.