BOJ 32869 흑백조경사
sorohue가 PS하는 블로그
BOJ 32869 흑백조경사
문제 링크입니다.
Even한지 판별하기
간단한 트리 DP로 각 서브트리가 Even한지 판별할 수 있습니다.
- 리프 노드는 Even합니다.
- 루트 노드의 색이 서브트리에 속하는 정점들의 색에서 다수를 차지하지 않는다면 Even하지 않습니다.
- Even하지 않은 노드가 서브트리에 껴 있으면 Even하지 않습니다.
루트를 적당히 1번으로 고정하고 문제를 푸는 것은 $\mathcal{O}(N)$에 할 수 있습니다.
루트 옮기기
흔히 리루팅이라 불리는 기법입니다.
트리에서 DFS를 돌면서 루트를 옮겨가면, 루트를 $u$에서 $u$와 인접한 정점인 $v$로 옮길 때 각 정점의 서브트리와 관련된 이런저런 값은 $u$와 $v$를 제외하고는 변경되지 않기 때문에, DFS 한 번으로 모든 정점이 루트일 때의 답을 만들어 낼 수 있습니다!
몇 가지 관찰을 통해 루트를 옮길 때 고려해야 할 사항들을 짚어냅시다.
- Even하지 않은 자식이 있다면, 해당 자식으로 루트를 내리지 않았을 때의 트리는 모두 Even하지 않습니다. 이로부터 Even하지 않은 자식이 둘 이상 달려 있는 경우 그 서브트리에 트리가 Even해질 수 있는 정점이 없음을 알 수 있습니다.
- 루트를 $u$에서 $v$로 내릴 때, $u$의 서브트리에는 $v$의 서브트리에 달려 있던 정점을 제외한 모든 정점이 달립니다.
각 정점 별로, 서브트리의 흰 정점 수, 검은 정점 수, Even한지의 여부, Even하지 않은 자식의 수를 첫 트리 DP에서 처리해 두고, 전체 트리의 각 색 정점 수와 구해 둔 값들을 이용해 리루팅을 돌리면 총 시간 복잡도 $\mathcal{O}(N)$에 문제를 해결할 수 있습니다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
using piii = pair<int, pii>;
int c[202020], w[202020], b[202020], ecnt[202020], tot_w, tot_b;
bool even[202020];
vector<int> e[202020];
vector<int> ans;
void dfs(int now, int pre){
even[now] = 1;
if(c[now]) b[now]=1;
else w[now]=1;
for(auto nxt : e[now]){
if(pre == nxt) continue;
dfs(nxt,now);
b[now] += b[nxt];
w[now] += w[nxt];
even[now] &= even[nxt];
if(!even[nxt]) ecnt[now]++;
}
if(c[now] && (w[now] >= b[now]) || !c[now] && (w[now] <= b[now])) even[now] = 0;
}
void solve(int now, int pre){
if(!ecnt[now] && (c[now] && (tot_w < tot_b) || !c[now] && (tot_w > tot_b))) ans.push_back(now);
if(ecnt[now] >= 2) return;
for(auto nxt : e[now]){
if(nxt == pre) continue;
if(ecnt[now] && even[nxt]) continue;
if(c[now] && (tot_w-w[nxt] >= tot_b-b[nxt]) || !c[now] && (tot_w-w[nxt] <= tot_b-b[nxt])) continue;
solve(nxt, now);
}
}
int main(){
cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
int n; cin >> n;
for(int i = 1; i <= n; i++){
cin >> c[i];
if(c[i]) tot_b++;
else tot_w++;
}
for(int i = 1; i < n; i++){
int u, v; cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, 1);
solve(1, 1);
sort(ans.begin(), ans.end());
cout << ans.size() << '\n';
for(auto i : ans) cout << i << ' ';
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.