포스트

BOJ 33807 Minimum Spanning Arborescence

sorohue가 PS하는 블로그

BOJ 33807 Minimum Spanning Arborescence

문제 링크입니다.

문제 요약

DAG의 각 간선에 $1$ 이상 $K$ 이하의 가중치를 무작위로 부여했을 때 MST(비슷한 무언가)를 구성하는 간선의 가중치 합의 기댓값을 구하세요.

풀이

문제에서 주어지는 그래프는 DAG이므로 정점을 위상 정렬할 수 있습니다. 스패닝 트리에 위상 정렬 순서대로 정점을 추가한다고 생각하면, 각 정점을 추가하기 전에 모든 부모 정점이 스패닝 트리에 먼저 추가됩니다. 따라서 각 정점을 스패닝 트리에 추가할 때 자신과 직접 연결되어 있는 진입 간선만 전부 고려해 주면 충분합니다. 각 간선이 어느 정점으로부터 나왔는지는 몰라도 됩니다! 게다가 각 간선의 가중치를 결정하는 방법이 모두 동일하므로, 기댓값을 구함에 있어서 중요한 것은 오직 진입 간선의 개수 뿐임을 알 수 있습니다.

Only the indegree is important!

기댓값의 선형성에 의해, 각 정점 별로 스패닝 트리에 추가했을 때의 가중치의 기댓값을 구한 뒤 모두 더하는 것으로 전체 스패닝 트리의 가중치 합의 기댓값을 구할 수 있어요.

어떤 정점 $v$의 진입 간선이 $i$개라고 생각합시다. 이때 스패닝 트리에 $v$를 추가할 때 더해지는 간선 가중치의 기댓값을 $E_i$라고 하겠습니다. $i$개의 간선 중 가중치의 최솟값이 $K-t+1$인 경우의 수는, $i$개의 간선이 각각 $K-t+1$부터 $K$까지 $t$가지 가중치 중 하나를 갖는 경우의 수에서 $K-t+2$부터 $K$까지 $t-1$개의 가중치 중 하나를 갖는 경우의 수를 뺀 값과 같습니다. 따라서 $E_i = \sum _{t = 1} ^{K} {t^i - (t-1)^i \over K^i} (K-t+1)$입니다. 이 식을 정리합시다.

$\sum _{t = 1} ^{K} {t^i - (t-1)^i \over K^i} (K-t+1) = (K+1)\sum _{t=1} ^{K} {t^i - (t-1)^i \over K^i} - \sum _{t=1} ^{K} {t^i - (t-1)^i \over K^i}t$ 입니다.

이때 $\sum _{t=1} ^{K} {t^i - (t-1)^i \over K^i}$는 가능한 모든 최솟값이 나올 확률의 합이므로 $1$임을 알 수 있습니다.

한편, $\sum _{t=1} ^{K} {t^i - (t-1)^i \over K^i}t = \sum _{t=1} ^{K} {t^i \over K^i}t - \sum _{t=1} ^{K} {(t-1)^i \over K^i}t = \sum _{t=1} ^{K} {t^i \over K^i}t - \sum _{t=1} ^{K} {(t-1)^i \over K^i}(t-1) - \sum _{t=1} ^{K} {(t-1)^i \over K^i}$이고,

$\sum _{t=1} ^{K} {t^i \over K^i}t - \sum _{t=1} ^{K} {(t-1)^i \over K^i}(t-1) = {K^i \over K^i}K = K$이므로 전체 식을 다음과 같이 나타낼 수 있습니다.

$E_i = K+1- \left( {K-\sum _{t=1}^K {(t-1)^i\over K^i}} \right) = 1+\sum _{t=1}^K {(t-1)^i\over K^i}$

이제 $\sum _{t=1}^K {(t-1)^i\over K^i}$의 값을 구하면 $E_i$의 값을 알아낼 수 있습니다. 그냥 계산하면 ${\cal O}(K \lg i)$의 시간이 걸리므로 총 시간 복잡도는 ${\cal O}(NK lg M)$가 되겠네요.

사실 $M$개의 간선을 최대한 다양한 진입 차수가 나오도록 분배해도 그 종류가 많아야 ${\cal O}(\sqrt{M})$개임을 알 수 있습니다. $0+1+2+\cdots$와 같은 형태로 진입 차수를 구성하는 게 최선인데, 이때 필요한 간선 수가 제곱 스케일로 커지기 때문입니다.

따라서 필요할 때마다 합을 구해 저장하는 방식을 쓰면 시간 복잡도가 ${\cal O}(N+\sqrt{M} K lg M)$이 됩니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
const ll mod = 998244353;

ll pw(ll n, ll r){
	ll ret = 1;
	while(r){
		if(r&1) ret = ret*n%mod;
		n = n*n%mod;
		r >>= 1;
	}
	return ret;
}

vector<int> e[101010];
int deg[101010];
ll sums[101010];

int main(){
	cin.tie(0);cout.tie(0);ios::sync_with_stdio(false);
	int n, m, k; cin >> n >> m >> k;
	while(m--){
		int u, v; cin >> u >> v; e[u].push_back(v); deg[v]++;
	}
	ll ans = 0;
	for(int i = 2; i <= n; i++){
		if(sums[deg[i]] == 0){
			ll inv = pw(pw(k, deg[i]), mod-2);
			ll tmp = 0;
			for(ll j = 1; j < k; j++) tmp = (tmp+pw(j, deg[i]))%mod;
			sums[deg[i]] = tmp*inv%mod;
		}
		ans = (ans+1+sums[deg[i]]+mod)%mod;
	}
	cout << ans;
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.