포스트

BOJ 34822 함수동상 그래프

sorohue가 PS하는 블로그

BOJ 34822 함수동상 그래프

문제 링크입니다.

문제 요약

함수형 그래프의 정점 몇 개 위에 동상이 있습니다. 동상을 간선을 따라 적절히 옮겨서 얻을 수 있는 동상이 올라간 정점 집합의 수를 구하세요.

풀이

컴포넌트가 여러 개인 경우 각 컴포넌트에 대한 경우의 수의 곱을 구하면 됩니다. 앞으로 모든 논의는 단일 컴포넌트만 고려합니다.

함수형 그래프는 컴포넌트 당 정확히 하나의 사이클을 갖습니다. 이 사이클에 집중합시다. 사이클 위에 $i$개의 정점이 올라가도록 동상을 밀어넣는 경우의 수를 구할 수 있다면, 사이클 위의 동상들은 사이클 위에서 원하는 대로 움직일 수 있으므로 이항 계수로 그 경우의 수를 쉽게 계산할 수 있습니다.

사이클을 제외한 함수형 그래프의 나머지 부분은 트리입니다. $d[x][i]$를 $x$번 정점의 서브트리에 $i$개의 동상을 남기는 경우의 수로 정의합니다. 이 DP는 부모 정점 방향으로 전파될 때 다항식 곱셈처럼 작동합니다. 부모 정점에서 봤을 때 $i$개의 동상이 남아있으려면, 그 서브트리의 합집합(부모 자신을 제외하고)에서는 합쳐서 $i$개 또는 $i-1$개의 동상이 나와야 함을 고려하면, 인덱스의 합이 $i$ 또는 $i-1$로 되어야 한다는 점에서 이런 인사이트를 찾을 수 있습니다.

아무튼 그래서 이 DP의 전파 시간은 나이브하게 전파할 때 서브트리의 크기 곱에 비례합니다. 이때 이러한 전파를 트리 전체에서 수행했을 때의 총 시간 복잡도는 ${\cal O}(N^2)$입니다(검은 돌 트릭으로 알려져 있습니다).

이를 통해 사이클마다 몇 개의 동상을 넣을 수 있는 경우의 수 또한 ${\cal O}(N^2)$에 구할 수 있어, 총 시간 복잡도 ${\cal O}(N^2)$에 문제를 해결할 수 있습니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
const ll mod = 1e9+7;
const int MAXN = 8080;
const int LG = 13;
 
ll pw(ll n, ll r){
	ll ret = 1;
	while(r){
		if(r&1) ret = ret*n%mod;
		n = n*n%mod;
		r >>= 1;
	}
	return ret;
}
 
ll fac[MAXN], inv[MAXN];
ll C(ll n, ll r){return fac[n]*inv[r]%mod*inv[n-r]%mod;}
 
int e[LG+1][MAXN]; bool statue[MAXN];
vector<int> rev[MAXN];
int incycle[MAXN]; int cyccnt;
vector<vector<vector<ll>>> components;
vector<int> cyclestatue;
vector<int> cyclesz(1);
vector<vector<ll>> d;
 
void mult(vector<ll>& a, vector<ll>& b){
	vector<ll> ret(a.size()+b.size()-1);
	for(int i = 0; i < a.size(); i++) for(int j = 0; j < b.size(); j++) ret[i+j] = (ret[i+j]+a[i]*b[j])%mod;
	a = ret;
}
 
vector<ll>& dp(int now){
	vector<ll>& ret = d[now];
	ret = {1,1};
	for(auto& nxt : rev[now]) mult(ret, dp(nxt));
	if(!statue[now]) ret.pop_back();
	return ret;
}
 
int main(){
	cin.tie(0);cout.tie(0);ios::sync_with_stdio(false);
	int n, k; cin >> n >> k;
	for(int i = 1; i <= n; i++) cin >> e[0][i], rev[e[0][i]].push_back(i);
	for(int i = 1; i <= k; i++){int v; cin >> v; statue[v] = 1;}
	
	fac[0] = 1; for(int i = 1; i <= n; i++) fac[i] = fac[i-1]*i%mod;
	inv[n] = pw(fac[n], mod-2); for(int i = n; i; i--) inv[i-1] = inv[i]*i%mod;
	for(int bit = 1; bit <= LG; bit++) for(int i = 1; i <= n; i++) e[bit][i] = e[bit-1][e[bit-1][i]];
	
	for(int i = 1; i <= n; i++) if(!incycle[e[LG][i]]){
		cyccnt++; cyclesz.push_back(0); int now = e[LG][i]; while(!incycle[now]){
			incycle[now] = cyccnt; now = e[0][now]; cyclesz.back()++;
		}
	}
	
	d.resize(n+1);
	components.resize(cyccnt+1); for(int i = 1; i <= cyccnt; i++) components[i].push_back({1});
	cyclestatue.resize(cyccnt+1); for(int i = 1; i <= n; i++) if(incycle[i]) cyclestatue[incycle[i]] += statue[i];
 
	for(int i = 1; i <= n; i++) if(!incycle[i] && incycle[e[0][i]]) components[incycle[e[0][i]]].push_back(dp(i));
 
	ll ans = 1;
	for(int i = 1; i <= cyccnt; i++){
		vector<ll>& branch = components[i][0];
		ll tmp = 0; for(int j = 1; j < components[i].size(); j++) mult(branch, components[i][j]);
		reverse(branch.begin(), branch.end()); //im too lazy to recalculate index
		for(int overflow = 0; overflow <= min((int)branch.size()-1, cyclesz[i]-cyclestatue[i]); overflow++){
			tmp += C(cyclesz[i], cyclestatue[i]+overflow)*branch[overflow]%mod; tmp %= mod;
		}
		ans = ans*tmp%mod;
	}
	cout << ans;
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.