BOJ 7726 개꿀잼
sorohue가 PS하는 블로그
문제 링크입니다.
문제 요약
$N \times M$ 크기의 행렬에서 $K \le 3$개의 겹치지 않는 부분행렬을 골랐을 때의 최대 원소 합을 구하세요.
크게 자르기
제한 조건에 의해, 전체 행렬을 $K$개의 겹치지 않는 직사각형으로 빈틈없이 채울 수 있습니다. 이말인즉슨 $K$개의 부분행렬을 다음과 같은 방법으로 만들어낼 수 있다는 의미입니다.
- 전체 행렬을 $K$개의 겹치지 않는 직사각형으로 나누기
- 각 직사각형에서 $1$개의 최대 부분행렬 찾기
이때 전체 행렬을 $3$개의 직사각형으로 분할하는 방법은 다음의 4가지 뿐입니다.
- 세로로 3분할
- 가로로 먼저 자른 후, 두 조각 중 하나를 세로로 자르기
- 세로로 먼저 자른 후, 두 조각 중 하나를 가로로 자르기
- 가로로 3분할
행렬의 가로와 세로를 뒤집으면 뒤의 두 종류는 앞의 두 종류와 똑같은 방법으로 구해질 수 있을 것입니다. 따라서 앞의 두 경우만 구할 수 있으면 문제를 해결할 수 있습니다. $K = 1$일 때와 $K=2$일 때는 $K = 3$을 풀면서 자연스럽게 답이 구해집니다.
자동문
가로로 먼저 자른 후 세로로 자르기에서의 최댓값을 구해 봅시다. 가로로 $1$번밖에 자르지 않는다는 건, 이때 생기는 두 조각 중 하나는 천장에, 다른 하나는 바닥에 붙는다는 의미입니다. 그러니 천장에서부터 내려오면서, 바닥에서부터 올라오면서, 1개 또는 2개의 부분행렬을 세로로 만들었을 때의 최대 합을 모두 구하면 이 경우에 대한 답을 구할 수 있습니다. (그리고 이 과정에서 $K = 1$일 때와 $K = 2$일 때의 답도 구해집니다.)
세로로 3분할의 경우, 아까와 비슷하지만 천장과 바닥 대신 왼쪽과 오른쪽에 붙은 조각 안에서 1개/2개를 만들었을 때의 답을 전처리해야 합니다.
런타임 도중의 전처리
이제 구해야 하는 값들을 구해 봅시다. 천장에서 바닥으로 내려가면서, 다음의 값을 구할 것입니다.
- $d_1[i][l][r] :=$ 행렬의 $1$행부터 $i$행까지, $l$열부터 $r$열까지의 조각 안에서 만들 수 있는 부분행렬의 최대 합
부분행렬의 양 끝 열이 $l$, $r$로 고정되어 있다면 이는 단순한 DP 문제가 됩니다. 우리는 양 끝 열이 고정되어 있을 필요는 없으니, 구간 $[l, r]$에 포함되는 DP 값들 중 제일 큰 걸 들고 와야 합니다. 이는 구간이 작은 것부터 순서대로 계산하면서, $[l+1, r]$과 $[l, r-1]$에서의 값을 확인하는 것으로 충분합니다.
이렇게 하면 ${\cal O}(NM^2)$만에 모든 $d_1[i][l][r]$ 값을 구할 수 있습니다. 이제 이를 이용하면,
- $d_2[i] :=$ 행렬의 $1$행부터 $i$행까지의 조각 안에서 $2$개의 부분행렬을 세로로 배치했을 때의 최대 합
를 각 $i$마다 $d_1 [i][1][x]$와 $d_1[i][x+1][M]$을 합하는 것으로 ${\cal O}(NM)$에 계산할 수 있습니다.
마지막으로 3분할의 경우를 구하기 위해, 전체 행렬의 왼쪽/오른쪽부터 누적해 가며 $2$개의 부분행렬을 세로로 배치했을 때의 최댓값을 구하면 됩니다. 각각의 값을 총 ${\cal O}(M^2)$에 계산할 수 있고, 그 뒤 3분할의 경우의 답을 구하는 건 ${\cal O}(M)$에 가능합니다.
이 작업을 상하좌우 모두에 대해 총 4번 수행해준 뒤 답을 구해주면, 총 시간 복잡도는 ${\cal O}(NM(N+M))$가 됩니다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
const ll INF = 123456123456LL;
int n, m, k;
ll a[333][333], mx[4][4][333][333], mxx[4][3][333], sum[333][333], sum_now[333][333], ans;
void solve(int dir){
mxx[dir][0][0] = mxx[dir][1][0] = mxx[dir][2][0] = -INF;
for(int t = 0; t <= 3; t++) for(int l = 0; l <= m; l++) for(int r = 0; r <= m; r++) mx[dir][t][l][r] = -INF;
for(int i = 1; i <= n; i++) for(int j = 1; j <= m; j++) sum[i][j] = sum[i][j-1]+a[i][j];
for(int i = 0; i <= m; i++) for(int j = 0; j <= m; j++) sum_now[i][j] = -INF;
for(int i = 1; i <= n; i++){
for(int delta = 0; delta < m; delta++){
for(int l = 1; l+delta <= m; l++){
sum_now[l][l+delta] = max(sum_now[l][l+delta], 0LL);
sum_now[l][l+delta] += sum[i][l+delta]-sum[i][l-1];
mx[dir][1][l][l+delta] = max(mx[dir][1][l][l+delta], sum_now[l][l+delta]);
if(delta >= 1) mx[dir][1][l][l+delta] = max({mx[dir][1][l][l+delta], mx[dir][1][l+1][l+delta], mx[dir][1][l][l+delta-1]});
}
}
for(int x = 1; x < m; x++) mxx[dir][2][i] = max(mxx[dir][2][i], mx[dir][1][1][x]+mx[dir][1][x+1][m]);
mxx[dir][1][i] = mx[dir][1][1][m];
}
for(int r = 2; r <= m; r++) for(int l = 2; l <= r; l++) mx[dir][2][1][r] = max(mx[dir][2][1][r], mx[dir][1][1][l-1]+mx[dir][1][l][r]);
for(int l = m-1; l >= 1; l--) for(int r = m-1; l <= r; r--) mx[dir][2][l][m] = max(mx[dir][2][l][m], mx[dir][1][l][r]+mx[dir][1][r+1][m]);
for(int i = 2; i < m; i++) mx[dir][3][1][m] = max(mx[dir][3][1][m], mx[dir][2][1][i]+mx[dir][1][i+1][m]);
for(int i = 1; i < m-1; i++) mx[dir][3][1][m] = max(mx[dir][3][1][m], mx[dir][1][1][i]+mx[dir][2][i+1][m]);
}
int main(){
cin.tie(0); cout.tie(0); ios::sync_with_stdio(false);
cin >> n >> m >> k;
for(int i = 1; i <= n; i++) for(int j = 1; j <= m; j++){cin >> a[i][j]; ans += a[i][j];}
solve(0); for(int i = 1; i <= n/2; i++) for(int j = 1; j <= m; j++) swap(a[i][j], a[n+1-i][j]); solve(1);
for(int i = 1; i <= max(n, m); i++) for(int j = 1; j < i; j++) swap(a[i][j], a[j][i]); swap(n, m);
solve(2);
for(int i = 1; i <= n/2; i++) for(int j = 1; j <= m; j++) swap(a[i][j], a[n+1-i][j]); solve(3); swap(n, m);
if(k == 1) return !(cout << mx[0][1][1][m]);
if(k == 2) return !(cout << max(mx[0][2][1][m], mx[2][2][1][n]));
ans = max({mx[0][3][1][m], mx[1][3][1][m], mx[2][3][1][n], mx[3][3][1][n]});
for(int i = 1; i < n; i++) ans = max({ans, mxx[0][1][i]+mxx[1][2][n-i], mxx[0][2][i]+mxx[1][1][n-i]});
for(int i = 1; i < m; i++) ans = max({ans, mxx[2][1][i]+mxx[3][2][m-i], mxx[2][2][i]+mxx[3][1][m-i]});
cout << ans;
}