BOJ 33159 트리핑
sorohue가 PS하는 블로그
BOJ 33159 트리핑
문제 링크입니다.
문제 요약
정점 집합이 주어질 때 그 정점들까지의 거리 합이 최소인 정점을 찾는 쿼리를 처리하세요.
풀이
트리의 어떤 정점들을 선택한 상황에서 답이 될 수 있는 정점만을 추려내 봅시다. 적당히 루트를 잡아서 선택한 정점에서 올라가는 경로들을 고려하면, 여러 개의 경로가 합쳐지는 지점에 도달하기 전까지는 거리 합이 선형적으로 변화하기 때문에 중간 과정에서 답이 최소가 될 수 없음을 확인할 수 있습니다.
따라서 우리는 선택한 정점과 그 LCA들만 존재하는 압축된 트리에서의 답만 고려해도 충분합니다. 트리 압축은 DFS 순회 상 인접한 정점들의 LCA만 찾아 추가해 주는 것으로 ${cal O}(K \lg K)$에 가능합니다.
압축된 트리에서 문제를 해결하는 것은 ${\cal O}(K)$ 리루팅 DP로 가능합니다. 문제에 $\Sigma K$ 의 조건이 있기 때문에 전체 문제를 ${\cal O}(K \lg K)$에 해결하는 것으로 충분합니다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
int eidx[303030], depth[303030], par[20][303030], n, k, IDX;
vector<int> naive[303030], v;
vector<pii> e[303030];
ll ans, tmp, sz[303030];
void init(int now, int pre){
depth[now] = depth[pre]+1;
par[0][now] = pre;
eidx[now] = ++IDX;
for(auto& nxt : naive[now]) if(nxt != pre) init(nxt, now);
if(now == 1) for(int bit = 1; bit < 20; bit++) for(int i = 1; i <= n; i++) par[bit][i] = par[bit-1][par[bit-1][i]];
}
int getLCA(int u, int v){
if(depth[u] < depth[v]) swap(u, v);
for(int bit = 19; bit >= 0; bit--) if(depth[u]-(1<<bit) >= depth[v]) u = par[bit][u];
if(u == v) return u;
for(int bit = 19; bit >= 0; bit--) if(par[bit][u] != par[bit][v]){u = par[bit][u]; v = par[bit][v];}
return par[0][u];
}
void dfs(int now){
for(auto& [nxt, w] : e[now]){
dfs(nxt);
sz[now] += sz[nxt];
ans += sz[nxt]*w;
}
}
void solve(int now){
ans = min(ans, tmp);
for(auto& [nxt, w] : e[now]){
tmp -= w*sz[nxt];
tmp += w*(k-sz[nxt]);
solve(nxt);
tmp += w*sz[nxt];
tmp -= w*(k-sz[nxt]);
}
}
int main(){
cin.tie(0);cout.tie(0);ios::sync_with_stdio(false);
int q; cin >> n >> q; for(int i = 1; i < n; i++){
int u, v; cin >> u >> v; naive[u].push_back(v); naive[v].push_back(u);
} init(1,1);
while(q--){
cin >> k; for(int i = 0; i < k; i++){int x; cin >> x; sz[x] = 1; v.push_back(x);}
sort(v.begin(), v.end(), [&](int& a, int& b){return eidx[a] < eidx[b];});
for(int i = 1; i < k; i++) v.push_back(getLCA(v[i-1], v[i]));
sort(v.begin(), v.end(), [&](int& a, int& b){return eidx[a] < eidx[b];});
v.erase(unique(v.begin(), v.end()), v.end());
for(int i = 1; i < v.size(); i++){int L = getLCA(v[i-1], v[i]); e[L].push_back({v[i], depth[v[i]]-depth[L]});}
ans = 0; dfs(v[0]); tmp = ans; solve(v[0]); cout << ans << '\n';
for(auto& i : v){sz[i] = 0; e[i].clear();} v.clear();
}
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.