BOJ 31114 Game Theory
sorohue가 PS하는 블로그
문제 링크입니다.
관찰
일단 아무 문자열이나 잡고 좀 건드려 봅시다.
01001011001 이 주어졌습니다. 이 문자열을 00000000000으로 만들어야 합니다.
처음에 1이 총 5개입니다. 5번째 글자가 0으로 바뀝니다. 01000011001
이제 1의 개수가 4개이므로, 4번째 글자가 1로 바뀝니다. 01010011001
마찬가지로 5번째, 6번째 글자가 1로 바뀝니다. 01011111001
그 뒤 7번째, 6번째, 5번째, 4번째 글자까지 순서대로 0으로 바뀝니다. 01000001001
가만히 보면 다음에 어떤 식으로 문자열이 바뀔지 예상할 수 있습니다. 3번째 글자에서 시작해 7번째 글자까지 1로 바뀌었다가, 8번째 글자부터 3번째 글자까지 다시 0으로 바뀌면서 되돌아오게 됩니다.
일반화
현재 문자열에 1이 총 $k$개라면, 그중 $i$번째 1의 위치를 $a_i$로 나타냅시다. 그러면 $1 \le a_1 < a_2 < \cdots < a_k$ 는 자연스럽습니다. 이때 문자열이 변화하는 패턴은 다음과 같습니다.
- 문자열의 $k$번째 비트부터 시작해서, 첫
1이 나올 때까지 오른쪽으로 진행합니다. - 첫
1을 만나면, $k-1$번째 비트가 나올 때까지 왼쪽으로 진행합니다.
이 과정을 한 번 거치고 나면, 처음 문자열에서 $k$번째 이상의 첫 1만 뒤집어지고 나머지는 그대로인 문자열이 얻어집니다.
처음으로 $a_i \ge k$를 만족하는 $i$의 값을 $t$로 두면, 패턴을 한 번 도는 데 필요한 행동의 수는 $2(a_t-k)+1$입니다. 문자열의 모든 비트가 0이 될 때까지 이 값을 다 더해주면 문제에서 구해야 하는 $f$ 값을 얻을 수 있습니다.
한 패턴이 끝나면 그때의 $a_t$는 0으로 뒤집어지기 때문에, 전체 과정 중에는 $a_1$부터 $a_k$까지의 모든 원소가 정확히 한 번씩 $a_t$ 자리에 들어갑니다. 그리고 $k$ 자리에는 $k$부터 $1$까지의 값이 한 번씩 들어갑니다. 따라서 이를 모두 더하면 아래의 식이 성립합니다.
쿼리 처리하기
이제 1의 위치의 합과 1의 개수를 세기만 하면 답을 구할 수 있습니다. 주어진 구간의 비트를 모두 뒤집는 쿼리를 처리해야 하므로, 느리게 갱신되는 세그먼트 트리를 이용하면 쿼리 당 $\mathcal{O}(\log N)$에 답을 구할 수 있습니다. 총 시간 복잡도는 $\mathcal{O} ((N+Q) \log N)$ 입니다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<bits/stdc++.h>
#define mid (l+r>>1)
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
constexpr ll mod = 998244353;
ll cnt[808080], sum[808080];
bool lazy[808080];
inline ll MD(ll x){
return (x%mod+2*mod)%mod;
}
void prop(int now, ll l, ll r){
if(!lazy[now]) return;
cnt[now] = MD(r-l+1-cnt[now]);
sum[now] = MD((l+r)*(r-l+1)/2-sum[now]);
if(l < r){
lazy[now<<1] ^= 1;
lazy[now<<1|1] ^= 1;
}
lazy[now] = 0;
}
void upd(int now, int l, int r, int L, int R){
prop(now, l, r);
if(l > R || L > r) return;
if(L <= l && r <= R){
lazy[now] ^= 1; prop(now, l, r); return;
}
upd(now<<1, l, mid, L, R); upd(now<<1|1, mid+1, r, L, R);
cnt[now] = MD(cnt[now<<1]+cnt[now<<1|1]);
sum[now] = MD(sum[now<<1]+sum[now<<1|1]);
}
string s;
void init(int now, ll l, ll r){
lazy[now] = 0;
if(l == r){
cnt[now] = (s[l-1]-'0');
sum[now] = l*cnt[now];
return;
}
init(now<<1, l, mid); init(now<<1|1, mid+1, r);
cnt[now] = MD(cnt[now<<1]+cnt[now<<1|1]);
sum[now] = MD(sum[now<<1]+sum[now<<1|1]);
}
int main(){
cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
int n, q; while(cin >> n >> q){
cin >> s; init(1, 1, n);
while(q--){
int L, R; cin >> L >> R;
upd(1, 1, n, L, R);
cout << MD(2*sum[1]-cnt[1]*cnt[1]) << '\n';
}
}
}