BOJ 34822 함수동상 그래프
sorohue가 PS하는 블로그
문제 링크입니다.
문제 요약
함수형 그래프의 정점 몇 개 위에 동상이 있습니다. 동상을 간선을 따라 적절히 옮겨서 얻을 수 있는 동상이 올라간 정점 집합의 수를 구하세요.
풀이
컴포넌트가 여러 개인 경우 각 컴포넌트에 대한 경우의 수의 곱을 구하면 됩니다. 앞으로 모든 논의는 단일 컴포넌트만 고려합니다.
함수형 그래프는 컴포넌트 당 정확히 하나의 사이클을 갖습니다. 이 사이클에 집중합시다. 사이클 위에 $i$개의 정점이 올라가도록 동상을 밀어넣는 경우의 수를 구할 수 있다면, 사이클 위의 동상들은 사이클 위에서 원하는 대로 움직일 수 있으므로 이항 계수로 그 경우의 수를 쉽게 계산할 수 있습니다.
사이클을 제외한 함수형 그래프의 나머지 부분은 트리입니다. $d[x][i]$를 $x$번 정점의 서브트리에 $i$개의 동상을 남기는 경우의 수로 정의합니다. 이 DP는 부모 정점 방향으로 전파될 때 다항식 곱셈처럼 작동합니다. 부모 정점에서 봤을 때 $i$개의 동상이 남아있으려면, 그 서브트리의 합집합(부모 자신을 제외하고)에서는 합쳐서 $i$개 또는 $i-1$개의 동상이 나와야 함을 고려하면, 인덱스의 합이 $i$ 또는 $i-1$로 되어야 한다는 점에서 이런 인사이트를 찾을 수 있습니다.
아무튼 그래서 이 DP의 전파 시간은 나이브하게 전파할 때 서브트리의 크기 곱에 비례합니다. 이때 이러한 전파를 트리 전체에서 수행했을 때의 총 시간 복잡도는 ${\cal O}(N^2)$입니다(검은 돌 트릭으로 알려져 있습니다).
이를 통해 사이클마다 몇 개의 동상을 넣을 수 있는 경우의 수 또한 ${\cal O}(N^2)$에 구할 수 있어, 총 시간 복잡도 ${\cal O}(N^2)$에 문제를 해결할 수 있습니다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
const ll mod = 1e9+7;
const int MAXN = 8080;
const int LG = 13;
ll pw(ll n, ll r){
ll ret = 1;
while(r){
if(r&1) ret = ret*n%mod;
n = n*n%mod;
r >>= 1;
}
return ret;
}
ll fac[MAXN], inv[MAXN];
ll C(ll n, ll r){return fac[n]*inv[r]%mod*inv[n-r]%mod;}
int e[LG+1][MAXN]; bool statue[MAXN];
vector<int> rev[MAXN];
int incycle[MAXN]; int cyccnt;
vector<vector<vector<ll>>> components;
vector<int> cyclestatue;
vector<int> cyclesz(1);
vector<vector<ll>> d;
void mult(vector<ll>& a, vector<ll>& b){
vector<ll> ret(a.size()+b.size()-1);
for(int i = 0; i < a.size(); i++) for(int j = 0; j < b.size(); j++) ret[i+j] = (ret[i+j]+a[i]*b[j])%mod;
a = ret;
}
vector<ll>& dp(int now){
vector<ll>& ret = d[now];
ret = {1,1};
for(auto& nxt : rev[now]) mult(ret, dp(nxt));
if(!statue[now]) ret.pop_back();
return ret;
}
int main(){
cin.tie(0);cout.tie(0);ios::sync_with_stdio(false);
int n, k; cin >> n >> k;
for(int i = 1; i <= n; i++) cin >> e[0][i], rev[e[0][i]].push_back(i);
for(int i = 1; i <= k; i++){int v; cin >> v; statue[v] = 1;}
fac[0] = 1; for(int i = 1; i <= n; i++) fac[i] = fac[i-1]*i%mod;
inv[n] = pw(fac[n], mod-2); for(int i = n; i; i--) inv[i-1] = inv[i]*i%mod;
for(int bit = 1; bit <= LG; bit++) for(int i = 1; i <= n; i++) e[bit][i] = e[bit-1][e[bit-1][i]];
for(int i = 1; i <= n; i++) if(!incycle[e[LG][i]]){
cyccnt++; cyclesz.push_back(0); int now = e[LG][i]; while(!incycle[now]){
incycle[now] = cyccnt; now = e[0][now]; cyclesz.back()++;
}
}
d.resize(n+1);
components.resize(cyccnt+1); for(int i = 1; i <= cyccnt; i++) components[i].push_back({1});
cyclestatue.resize(cyccnt+1); for(int i = 1; i <= n; i++) if(incycle[i]) cyclestatue[incycle[i]] += statue[i];
for(int i = 1; i <= n; i++) if(!incycle[i] && incycle[e[0][i]]) components[incycle[e[0][i]]].push_back(dp(i));
ll ans = 1;
for(int i = 1; i <= cyccnt; i++){
vector<ll>& branch = components[i][0];
ll tmp = 0; for(int j = 1; j < components[i].size(); j++) mult(branch, components[i][j]);
reverse(branch.begin(), branch.end()); //im too lazy to recalculate index
for(int overflow = 0; overflow <= min((int)branch.size()-1, cyclesz[i]-cyclestatue[i]); overflow++){
tmp += C(cyclesz[i], cyclestatue[i]+overflow)*branch[overflow]%mod; tmp %= mod;
}
ans = ans*tmp%mod;
}
cout << ans;
}