포스트

BOJ 21843 Game Show

sorohue가 PS하는 블로그

BOJ 21843 Game Show

문제 링크입니다.

문제 요약

세 팀이 $3N$개의 문제 중 $N$개씩을 맡아 풉니다. A팀이 먼저 $N$개의 문제를 고르고, B팀이 남은 문제 중 $N$개를 골라 가져갑니다. 남은 문제는 C팀이 가져갑니다.

문제마다 팀 별로 그 문제를 풀 확률이 정해져 있습니다. A팀은 자신이 푼 문제 수에서 각 팀이 푼 평균 문제 수를 뺀 만큼 상금을 얻습니다.

B팀이 A팀의 상금 기댓값을 최소화하는 전략을 사용할 때 A팀의 상금 기댓값을 최대화하세요.

B의 전략

B의 목표는 B가 푼 문제 수 + C가 푼 문제 수의 기댓값을 최대화하는 것입니다. C팀에게 자신이 푸는 것보다 C팀 쪽에서 푸는 게 기댓값이 더 높아지는 문제들 위주로 보내는 게 최선일 것임을 직관적으로 알 수 있습니다. 이를 다시 쓰면, $2N$개의 문제를 $p_B - p_C$가 큰 순으로 정렬해 앞쪽 $N$개 문제를 B팀이, 뒤쪽 $N$개 문제를 C팀이 맡도록 배분하는 것이 B팀의 최적 전략입니다.

A의 전략

A가 $2N$개의 문제를 남기면 B의 전략에 따라 각 문제가 어느 팀에게 갈지 알 수 있습니다. 즉 A는 B와 C에게 각각 $N$문제를 배당해 주고 남는 $N$개를 가져가는 식으로 문제를 고를 수 있습니다.

이를 위해서는 $p_B - p_C$ 순으로 문제를 정렬했을 때 앞에서 $N$개, 뒤에서 $N$개를 골라내야 합니다. 이를 앞쪽 $i \ge N$개 중 $N$개를 골라 $B$에게 넘겨주는 행동과 뒤쪽 $3N-i \ge N$개 중 $N$개를 골라 $C$에게 넘겨주는 행동으로 분할하면, 각 경우의 최댓값은 우선순위 큐 등을 이용해 incremental하게 구할 수 있습니다. 구한 값들을 $i$에 따라 합쳐 그중 최댓값을 취하면 정답을 얻을 수 있습니다.

총 시간 복잡도는 ${\cal O}(N \lg N)$입니다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include<bits/stdc++.h>
using namespace std;
using ld = long double;

vector<array<ld, 3>> a;
priority_queue<ld> pq;
vector<ld> b, c;
ld ans = -1e18, now;

int main(){
	cin.tie(0);cout.tie(0);ios::sync_with_stdio(false);
	int n; cin >> n; a.resize(3*n); for(int j = 0; j < 3; j++) for(int i = 0; i < 3*n; i++) cin >> a[i][j];
	//bigger b-c goes to b, smaller b-c goes to c
	sort(a.begin(), a.end(), [&](array<ld, 3>& a, array<ld, 3>& b){
		return a[1]-a[2] > b[1]-b[2];
	});
	
	for(int i = 0; i < n; i++) pq.push(a[i][0]*2+a[i][1]), now -= a[i][1];
	b.push_back(now);
	for(int i = 0; i < n; i++){
		pq.push(a[i+n][0]*2+a[i+n][1]); now -= a[i+n][1]-pq.top(); pq.pop();
		b.push_back(now);
	}
	while(pq.size()) pq.pop(); now = 0;
	for(int i = 1; i <= n; i++) pq.push(a[3*n-i][0]*2+a[3*n-i][2]), now -= a[3*n-i][2];
	c.push_back(now);
	for(int i = 1; i <= n; i++){
		pq.push(a[2*n-i][0]*2+a[2*n-i][2]); now -= a[2*n-i][2]-pq.top(); pq.pop();
		c.push_back(now);
	}
	for(int i = 0; i <= n; i++) ans = max(ans, b[i]+c[n-i]);
	cout << fixed << setprecision(12) << ans*1000/3;
}
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.