포스트

BOJ 5813 이상적인 도시

sorohue가 PS하는 블로그

BOJ 5813 이상적인 도시

문제 링크입니다.

상특) 서브태스크 안 품

바로 만점 풀이로 갑시다. 아이디어 하나로 미는 타입의 문제라 서브태스크는 굳이 싶네요.

일단, $x$방향으로의 이동 거리와 $y$방향으로의 이동 거리를 분리할 수 있습니다. 그러니 $x$방향의 이동 거리만 생각해 봅시다. $y$ 방향으로는 아무렇게나 이동할 수 있다고 생각하면, 위아래로 붙어있는 막대기들을 단일 칸으로 압축할 수 있습니다.

그렇게 하면 도시가 구멍이 없는 한 덩어리라는 조건 때문에 압축된 도시가 트리 구조를 이룹니다. 와!

트리의 각 정점에 그 정점으로 압축된 칸의 수를 가중치로 주고 트리 DP를 짜면 모든 칸의 쌍에 대한 $x$방향으로의 이동 거리의 합을 구할 수 있습니다.

$y$좌표에 대해서는 모든 칸의 좌표를 뒤집고 같은 계산을 반복하는 것으로 해결할 수 있습니다.

압축에 set을 쓰기 때문에 시간 복잡도는 $\mathcal O(N \lg N)$이 됩니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include<bits/stdc++.h>
#define x first
#define y second
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
using plll = tuple<ll, ll, ll>;
const ll mod = 1e9;

vector<set<int>> e;
map<pll, int> comp;
int par[101010];

int f(int x){
	return par[x] < 0 ? x : par[x] = f(par[x]);
}
void u(int x, int y){
	x = f(x); y = f(y);
	if(x == y) return;
	par[x] += par[y];
	par[y] = x;
}

plll dfs(int now, int pre){
	ll sum = 0, rtd = 0, cnt = -par[now];
	for(int nxt : e[now]){
		if(nxt == pre) continue;
		auto [nsum, nrtd, ncnt] = dfs(nxt, now);
		sum = (sum+nsum)%mod; sum = (sum+cnt*(nrtd+ncnt)%mod)%mod;
		sum = (sum+rtd*ncnt%mod)%mod;
		rtd = (rtd+nrtd+ncnt)%mod; cnt = (cnt+ncnt)%mod;
	}
	return {sum, rtd, cnt};
}

ll solve(vector<pll> p){
	memset(par, -1, sizeof(par));
	comp.clear();
	e.clear(); e.resize(p.size());
	for(int i = 0; i < p.size(); i++) comp[p[i]] = i;
	for(int i = 0; i < p.size(); i++){
		if(comp.count({p[i].x-1, p[i].y})) u(comp[p[i]], comp[{p[i].x-1, p[i].y}]);
	}
	for(int i = 0; i < p.size(); i++){
		if(comp.count({p[i].x, p[i].y-1})){
			e[f(comp[p[i]])].insert(f(comp[{p[i].x, p[i].y-1}]));
			e[f(comp[{p[i].x, p[i].y-1}])].insert(f(comp[p[i]]));
		}
	}
	auto [sum, rtd, cnt] = dfs(f(0), f(0));
	return sum;
}

int main(){
	cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
	int n; cin >> n; vector<pll> p(n);
	for(int i = 0; i < n; i++) cin >> p[i].x >> p[i].y;
	ll ans = solve(p);
	for(auto& a : p) swap(a.x, a.y);
	ans += solve(p);
	cout << ans%mod;
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.