BOJ 15481 그래프와 MST
sorohue가 PS하는 블로그
문제 링크입니다.
문제 요약
무향 가중치 연결 그래프 $G$가 주어집니다. $G$의 각 간선마다, 그 간선을 포함하는 스패닝 트리의 최소 가중치 합을 구하세요.
풀이
일단 $G$의 MST를 구하는 것에서부터 발상을 시작해 봅시다. 만약 간선이 이미 MST에 포함되어 있다면 그냥 MST의 가중치 합을 뱉으면 되니 넘어갑시다.
MST에 포함되어 있지 않은 간선이 두 정점 $u, v$를 잇는다고 생각해 봅시다. 그러면 해당 간선을 포함하는 스패닝 트리는 ($u$를 포함하는 부분그래프의 MST) - ($v$를 포함하는 부분 그래프의 MST) 꼴로 표현할 수 있습니다.
그런데 각 부분그래프의 MST는 전체 그래프의 MST의 서브트리일 수밖에 없습니다. 그게 아니면 원래 그래프에 더 나은 MST가 있다는 의미가 되니까요. 그래서 우리의 스패닝 트리는 MST에서 두 정점 $u, v$ 사이의 경로를 이루는 간선 중 하나를 지우고 $u$와 $v$를 잇는 간선을 끼워넣은 형태일 수밖에 없습니다.
트리에서 경로를 이루는 간선 가중치 중 최댓값을 빠르게 구하면 문제를 해결할 수 있습니다. LCA를 희소 배열로 구하는 방법을 응용해, 각 정점마다 부모 정점과 함께 자신과 부모 정점을 잇는 간선의 가중치를 저장합시다. 그 뒤 조상에게로 올라가는 경로 상에서의 간선 가중치 최댓값을 희소 배열로 저장해 주면, LCA를 구하면서 우리가 제거할 간선의 가중치도 빠르게 구할 수 있습니다.
이상을 구현하면 총 시간복잡도 ${\cal O}(M \lg N)$에 문제를 해결할 수 있습니다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
vector<tuple<int,int,int>> e;
priority_queue<pll> pq;
vector<pll> mst[202020];
int par[20][202020], depth[202020];
ll max_edge[20][202020];
int uf_par[202020];
ll ans;
int Find(int x){
return uf_par[x] < 0 ? x : uf_par[x] = Find(uf_par[x]);
}
bool Union(int x, int y){
x = Find(x); y = Find(y);
if(x == y) return 0;
if(uf_par[x] > uf_par[y]) swap(x, y);
uf_par[x] += uf_par[y];
uf_par[y] = x;
return 1;
}
void dfs(int now, int pre){
for(auto [nxt, w] : mst[now]){
if(nxt == pre) continue;
par[0][nxt] = now;
max_edge[0][nxt] = w;
depth[nxt] = depth[now]+1;
dfs(nxt, now);
}
}
ll solve(int u, int v){
ll ret = 0;
if(depth[u] > depth[v]) swap(u, v);
for(int bit = 19; bit >= 0; bit--){
if(depth[u] <= depth[v]-(1<<bit)){
ret = max(ret, max_edge[bit][v]);
v = par[bit][v];
}
}
if(u == v) return ret;
for(int bit = 19; bit >= 0; bit--){
if(par[bit][u] != par[bit][v]){
ret = max({ret, max_edge[bit][u], max_edge[bit][v]});
u = par[bit][u];
v = par[bit][v];
}
}
ret = max({ret, max_edge[0][u], max_edge[0][v]});
return ret;
}
int main(){
cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
int n, m; cin >> n >> m;
for(int i = 0; i < m; i++){
int u, v, w; cin >> u >> v >> w;
e.emplace_back(u,v,w);
pq.push({-w, i});
}
memset(uf_par, -1, sizeof(uf_par));
while(pq.size()){
auto [dummy, i] = pq.top(); pq.pop();
auto [u, v, w] = e[i];
if(Union(u, v)){
mst[u].push_back({v, w});
mst[v].push_back({u, w});
ans += w;
}
}
par[0][1] = 1; max_edge[0][1] = 0;
dfs(1,1); for(int bit = 1; bit < 20; bit++){
for(int i = 1; i <= n; i++){
par[bit][i] = par[bit-1][par[bit-1][i]];
max_edge[bit][i] = max(max_edge[bit-1][i], max_edge[bit-1][par[bit-1][i]]);
}
}
for(int i = 0; i < m; i++){
auto [u, v, w] = e[i];
cout << ans+w-solve(u, v) << '\n';
}
}