포스트

BOJ 15481 그래프와 MST

sorohue가 PS하는 블로그

BOJ 15481 그래프와 MST

문제 링크입니다.

문제 요약

무향 가중치 연결 그래프 $G$가 주어집니다. $G$의 각 간선마다, 그 간선을 포함하는 스패닝 트리의 최소 가중치 합을 구하세요.

풀이

일단 $G$의 MST를 구하는 것에서부터 발상을 시작해 봅시다. 만약 간선이 이미 MST에 포함되어 있다면 그냥 MST의 가중치 합을 뱉으면 되니 넘어갑시다.

MST에 포함되어 있지 않은 간선이 두 정점 $u, v$를 잇는다고 생각해 봅시다. 그러면 해당 간선을 포함하는 스패닝 트리는 ($u$를 포함하는 부분그래프의 MST) - ($v$를 포함하는 부분 그래프의 MST) 꼴로 표현할 수 있습니다.

그런데 각 부분그래프의 MST는 전체 그래프의 MST의 서브트리일 수밖에 없습니다. 그게 아니면 원래 그래프에 더 나은 MST가 있다는 의미가 되니까요. 그래서 우리의 스패닝 트리는 MST에서 두 정점 $u, v$ 사이의 경로를 이루는 간선 중 하나를 지우고 $u$와 $v$를 잇는 간선을 끼워넣은 형태일 수밖에 없습니다.

트리에서 경로를 이루는 간선 가중치 중 최댓값을 빠르게 구하면 문제를 해결할 수 있습니다. LCA를 희소 배열로 구하는 방법을 응용해, 각 정점마다 부모 정점과 함께 자신과 부모 정점을 잇는 간선의 가중치를 저장합시다. 그 뒤 조상에게로 올라가는 경로 상에서의 간선 가중치 최댓값을 희소 배열로 저장해 주면, LCA를 구하면서 우리가 제거할 간선의 가중치도 빠르게 구할 수 있습니다.

이상을 구현하면 총 시간복잡도 ${\cal O}(M \lg N)$에 문제를 해결할 수 있습니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;

vector<tuple<int,int,int>> e;
priority_queue<pll> pq;
vector<pll> mst[202020];
int par[20][202020], depth[202020];
ll max_edge[20][202020];
int uf_par[202020];
ll ans;

int Find(int x){
	return uf_par[x] < 0 ? x : uf_par[x] = Find(uf_par[x]);
}

bool Union(int x, int y){
	x = Find(x); y = Find(y);
	if(x == y) return 0;
	if(uf_par[x] > uf_par[y]) swap(x, y);
	uf_par[x] += uf_par[y];
	uf_par[y] = x;
	return 1;
}

void dfs(int now, int pre){
	for(auto [nxt, w] : mst[now]){
		if(nxt == pre) continue;
		par[0][nxt] = now;
		max_edge[0][nxt] = w;
		depth[nxt] = depth[now]+1;
		dfs(nxt, now);
	}
}

ll solve(int u, int v){
	ll ret = 0;
	if(depth[u] > depth[v]) swap(u, v);
	for(int bit = 19; bit >= 0; bit--){
		if(depth[u] <= depth[v]-(1<<bit)){
			ret = max(ret, max_edge[bit][v]);
			v = par[bit][v];
		}
	}
	if(u == v) return ret;
	for(int bit = 19; bit >= 0; bit--){
		if(par[bit][u] != par[bit][v]){
			ret = max({ret, max_edge[bit][u], max_edge[bit][v]});
			u = par[bit][u];
			v = par[bit][v];
		}
	}
	ret = max({ret, max_edge[0][u], max_edge[0][v]});
	return ret;
}

int main(){
	cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
	int n, m; cin >> n >> m;
	for(int i = 0; i < m; i++){
		int u, v, w; cin >> u >> v >> w;
		e.emplace_back(u,v,w);
		pq.push({-w, i});
	}
	memset(uf_par, -1, sizeof(uf_par));
	while(pq.size()){
		auto [dummy, i] = pq.top(); pq.pop();
		auto [u, v, w] = e[i];
		if(Union(u, v)){
			mst[u].push_back({v, w});
			mst[v].push_back({u, w});
			ans += w;
		}
	}
	par[0][1] = 1; max_edge[0][1] = 0;
	dfs(1,1); for(int bit = 1; bit < 20; bit++){
		for(int i = 1; i <= n; i++){
			par[bit][i] = par[bit-1][par[bit-1][i]];
			max_edge[bit][i] = max(max_edge[bit-1][i], max_edge[bit-1][par[bit-1][i]]);
		}
	}
	for(int i = 0; i < m; i++){
		auto [u, v, w] = e[i];
		cout << ans+w-solve(u, v) << '\n';
	}
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.