포스트

BOJ 32472 Simple Tree Decomposition Problem

sorohue가 PS하는 블로그

BOJ 32472 Simple Tree Decomposition Problem

문제 링크입니다.

나이브 DP 돌리기

$d[now][x]$를 정점 $now$의 서브트리를, $now$를 포함하며 크기가 $x$인 컴포넌트 하나와 크기가 $a$ 또는 $b$인 컴포넌트들로 분해하는 방법의 수로 정의합니다. 루트를 $1$번 정점으로 고정하면 답은 $d[1][0]$입니다.

DP 값을 관리할 때 $0 \le x < b$인 경우만 생각하는 것으로 충분합니다. $x \ge b$면 아래쪽부터 적당히 컴포넌트를 만들어 $x <b$로 만들거나, 빈 공간이 생겨 그것이 불가능한 경우뿐이기 때문입니다.

두 트리를 합칠 때마다, DP 값을 갱신하기 위해 두 트리에서의 가능한 $x$의 조합을 모두 고려해야 합니다. 대강 $d[now][x] = \sum_{i} (d[u][i]\times d[v][x-i])$ 꼴입니다. 이 과정에서 모든 $d[u][i]$와 $d[v][i]$의 순서쌍을 순회하므로, 하나의 트리를 합칠 때마다 총 $min(size_u , b)\times min(size_v , b)$ 회의 연산이 돌아갑니다. 트리를 합치는 이벤트는 $\mathcal{O}(N)$회 일어나므로, 시간 복잡도는 $\mathcal O(Nb^2)$입니다.

…만, 이대로 짜면 AC를 받을 수 있습니다!

나이브?

트리를 ETT를 이용해 구간으로 바꿔치면, 두 서브트리를 합치는 연산은 두 인접한 구간을 합치는 것과 같습니다. 두 구간을 합치기 위해 $min(size_u , b)\times min(size_v , b)$회의 연산이 돌아간다는 건, 왼쪽 구간의 마지막 $b$개 원소와 오른쪽 구간의 첫 $b$개 원소를 잇는 것과 같습니다.

이 행동은 궁극적으로 모든 원소마다 오른쪽으로 $2b$개 정도의 원소를 잇게 됩니다. 따라서 트리 DP를 돌리는 과정에서 필요한 총 연산 횟수는 $\mathcal{O}(Nb)$입니다. 제한 안에 돌 정도의 시간 복잡도지만, 상수 관리에 어느 정도 유의할 필요가 있습니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
const ll mod = 1e9+7;

ll d[101010][555]; int sz[101010]; int n, a, b;
vector<int> e[101010];

void dfs(int now, int pre){
	sz[now] = 1; d[now][1] = 1; if(a == 1) d[now][0] = 1;
	for(auto nxt : e[now]){
		if(nxt == pre) continue;
		dfs(nxt, now);
		ll full = 0;
		for(int sum = min(b, sz[now]+sz[nxt]); sum; sum--){
			ll tmp = 0;
			for(int i = min(sum, sz[now]); i >= max(1,sum-sz[nxt]); i--){
				int j = sum-i;
				tmp += d[now][i]*d[nxt][j]%mod;
				if((sum == b || sum == a) && j) full += d[now][i]*d[nxt][j]%mod;
			}
			d[now][sum] = tmp%mod;
		}
		d[now][0] *= d[nxt][0];
		d[now][0] += full%mod;
		d[now][0] %= mod;
		sz[now] += sz[nxt];
	}
}

int main(){
	cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
	cin >> n >> a >> b;
	for(int i = 1; i < n; i++){
		int u, v; cin >> u >> v;
		e[u].push_back(v);
		e[v].push_back(u);
	}
	dfs(1, 1);
	cout << d[1][0];
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.