포스트

BOJ 16216 우산

sorohue가 PS하는 블로그

BOJ 16216 우산

문제 링크입니다.

문제 요약

트리의 $1$번 정점에서 출발해, 특정 $k$개 정점 중 적절히 $i$ $( 1 \le i \le k)$개의 정점을 선택해 방문하는 최단 경로의 길이를 구하세요. 정점을 방문하는 순서는 임의로 정할 수 있고, $i$개의 정점을 모두 방문한 뒤 $1$번 정점으로 돌아오지 않아도 됩니다.

풀이

트리가 충분히 작다고 가정하고 문제를 해결해 봅시다. 각 서브트리에 대해서,

  1. 루트에서 출발해 $i$개의 특수한 정점을 방문한 후 루트로 돌아오는 최단 경로의 길이
  2. 루트에서 출발해 $i$개의 특수한 정점을 방문한 후 루트로 돌아오지 않는 최단 경로의 길이

를 모든 $i \le$ (서브트리 내의 특수한 정점 개수) 에 대해 트리 DP로 구할 수 있습니다. 이러한 형태의 DP는 ${\cal O}(Nk)$에 계산할 수 있습니다. (검은 돌 트릭 또는 Tree Optimization이라는 이름으로 알려져 있습니다.)

이 문제에서는 $N$의 제한이 $300\,000$으로 커서 위 풀이가 시간 안에 작동하지 않습니다. 하지만 $k$는 $5\,000$으로 꽤 작죠. 실제로 필요한 정점은 $k$개의 정점들과 $1$번 정점, 그리고 그 정점들로의 길이 구분되는 분기점에 해당하는 정점들 뿐입니다. 트리 압축 기법으로 ${\cal O}(k)$개의 정점만 남겨 압축하면 전체 문제를 ${\cal O}(k^2)$의 시간 복잡도로 해결할 수 있습니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
const ll INF = 1234123412341234LL;
bool c[303030];
vector<int> ett;
vector<int> naive[303030];
vector<pll> e[303030];
int depth[303030], par[20][303030], nidx[303030], sz[303030], eidx[303030], IDX;
ll d[2][10101][5050]; //circuit/path, nidx, k

void dfs(int now, int pre){
	depth[now] = depth[pre]+1;
	par[0][now] = pre;
	eidx[now] = ++IDX;
	if(c[now]) ett.push_back(now);
	for(auto& nxt : naive[now]){
		if(nxt == pre) continue;
		dfs(nxt, now);
	}
}

void init_par(int n){
	for(int bit = 1; bit < 20; bit++){
		for(int i = 1; i <= n; i++) par[bit][i] = par[bit-1][par[bit-1][i]];
	}
}

int getLCA(int u, int v){
	if(depth[u] < depth[v]) swap(u, v);
	for(int bit = 19; bit >= 0; bit--) if(depth[u]-(1<<bit) >= depth[v]) u = par[bit][u];
	if(u == v) return u;
	for(int bit = 19; bit >= 0; bit--) if(par[bit][u] != par[bit][v]){
		u = par[bit][u]; v = par[bit][v];
	}
	return par[0][u];
}

void rerabel(int now){
	nidx[now] = ++IDX;
	for(auto& [nxt, w] : e[now]) rerabel(nxt);
}

void solve(int now){
	if(c[now]) sz[now] = 1;
	for(auto& [nxt, w] : e[now]){
		solve(nxt);
		for(int i = sz[now]+1; i <= sz[now]+sz[nxt]; i++) d[0][nidx[now]][i] = d[1][nidx[now]][i] = INF;
		sz[now] += sz[nxt];
		for(int tot = sz[now]; tot >= 1; tot--) for(int sub = min(tot,sz[nxt]); sub >= 1; sub--){
			if(tot-sub > sz[now]-sz[nxt]) break;
			d[0][nidx[now]][tot] = min(d[0][nidx[now]][tot], d[0][nidx[now]][tot-sub]+d[0][nidx[nxt]][sub]+2*w);
			d[1][nidx[now]][tot] = min(d[1][nidx[now]][tot], d[1][nidx[now]][tot-sub]+d[0][nidx[nxt]][sub]+2*w);
			d[1][nidx[now]][tot] = min(d[1][nidx[now]][tot], d[0][nidx[now]][tot-sub]+d[1][nidx[nxt]][sub]+w);
		}
	}
}

int main(){
	cin.tie(0);cout.tie(0);ios::sync_with_stdio(false);
	int n, k; cin >> n >> k;
	for(int i = 1; i < n; i++){
		int u, v; cin >> u >> v;
		naive[u].push_back(v);
		naive[v].push_back(u);
	}
	for(int i = 0; i < k; i++){
		int t; cin >> t; c[t] = 1;
	}
	ett.push_back(1); dfs(1,1); init_par(n); int S = ett.size();
	for(int i = 1; i < S; i++) ett.push_back(getLCA(ett[i-1], ett[i]));
	sort(ett.begin(), ett.end(), [&](int& a, int& b){
		return eidx[a] < eidx[b];
	});
	ett.erase(unique(ett.begin(), ett.end()), ett.end());
	for(int i = 1; i < ett.size(); i++){
		int L = getLCA(ett[i-1], ett[i]);
		e[L].push_back({ett[i], depth[ett[i]]-depth[L]});
	}
	IDX = 0; rerabel(1); solve(1);
	for(int i = 1; i <= k; i++) cout << d[1][1][i] << '\n';
}

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.