포스트

BOJ 32869 흑백조경사

sorohue가 PS하는 블로그

BOJ 32869 흑백조경사

문제 링크입니다.

Even한지 판별하기

간단한 트리 DP로 각 서브트리가 Even한지 판별할 수 있습니다.

  • 리프 노드는 Even합니다.
  • 루트 노드의 색이 서브트리에 속하는 정점들의 색에서 다수를 차지하지 않는다면 Even하지 않습니다.
  • Even하지 않은 노드가 서브트리에 껴 있으면 Even하지 않습니다.

루트를 적당히 1번으로 고정하고 문제를 푸는 것은 $\mathcal{O}(N)$에 할 수 있습니다.

루트 옮기기

흔히 리루팅이라 불리는 기법입니다.

트리에서 DFS를 돌면서 루트를 옮겨가면, 루트를 $u$에서 $u$와 인접한 정점인 $v$로 옮길 때 각 정점의 서브트리와 관련된 이런저런 값은 $u$와 $v$를 제외하고는 변경되지 않기 때문에, DFS 한 번으로 모든 정점이 루트일 때의 답을 만들어 낼 수 있습니다!

몇 가지 관찰을 통해 루트를 옮길 때 고려해야 할 사항들을 짚어냅시다.

  • Even하지 않은 자식이 있다면, 해당 자식으로 루트를 내리지 않았을 때의 트리는 모두 Even하지 않습니다. 이로부터 Even하지 않은 자식이 둘 이상 달려 있는 경우 그 서브트리에 트리가 Even해질 수 있는 정점이 없음을 알 수 있습니다.
  • 루트를 $u$에서 $v$로 내릴 때, $u$의 서브트리에는 $v$의 서브트리에 달려 있던 정점을 제외한 모든 정점이 달립니다.

각 정점 별로, 서브트리의 흰 정점 수, 검은 정점 수, Even한지의 여부, Even하지 않은 자식의 수를 첫 트리 DP에서 처리해 두고, 전체 트리의 각 색 정점 수와 구해 둔 값들을 이용해 리루팅을 돌리면 총 시간 복잡도 $\mathcal{O}(N)$에 문제를 해결할 수 있습니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
using piii = pair<int, pii>;

int c[202020], w[202020], b[202020], ecnt[202020], tot_w, tot_b;
bool even[202020];
vector<int> e[202020];
vector<int> ans;

void dfs(int now, int pre){
	even[now] = 1;
	if(c[now]) b[now]=1;
	else w[now]=1;
	for(auto nxt : e[now]){
		if(pre == nxt) continue;
		dfs(nxt,now);
		b[now] += b[nxt];
		w[now] += w[nxt];
		even[now] &= even[nxt];
		if(!even[nxt]) ecnt[now]++;
	}
	if(c[now] && (w[now] >= b[now]) || !c[now] && (w[now] <= b[now])) even[now] = 0;
}

void solve(int now, int pre){
	if(!ecnt[now] && (c[now] && (tot_w < tot_b) || !c[now] && (tot_w > tot_b))) ans.push_back(now);
	if(ecnt[now] >= 2) return;
	for(auto nxt : e[now]){
		if(nxt == pre) continue;
		if(ecnt[now] && even[nxt]) continue;
		if(c[now] && (tot_w-w[nxt] >= tot_b-b[nxt]) || !c[now] && (tot_w-w[nxt] <= tot_b-b[nxt])) continue;
		solve(nxt, now);
	}
}

int main(){
	cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
	int n; cin >> n;
	for(int i = 1; i <= n; i++){
		cin >> c[i];
		if(c[i]) tot_b++;
		else tot_w++;
	}
	for(int i = 1; i < n; i++){
		int u, v; cin >> u >> v;
		e[u].push_back(v);
		e[v].push_back(u);
	}
	dfs(1, 1);
	solve(1, 1);
	sort(ans.begin(), ans.end());
	cout << ans.size() << '\n';
	for(auto i : ans) cout << i << ' ';
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.