BOJ 30513 하이퍼 삼각형 자르기
sorohue가 PS하는 블로그
문제 링크입니다.
제가 SASA Programming Contest 2023에 출제한 문제입니다.
원래는 제 친구가 2차원, 3차원에서의 문제를 들고 와서 내자고 한 문제였는데, OEIS에 검색해보니 각각 A002717과 A269747에 나와 있어서, 제가 냅다 더 높은 차원으로 문제를 확장시켰습니다.
관찰
$M$차원 하이퍼 삼각형을 주물러서, $M$개의 서로 다른 변이 각각 서로 수직하도록 만들어줍니다. 그러면 커다란 하이퍼 삼각형의 $M+1$개 꼭짓점을 각각 $M$차원 직교좌표공간에서의 좌표 $(0,0,\cdots,0),(N,0,\cdots,0),\cdots,(0,0,\cdots,N)$으로 표현할 수 있습니다. 직교좌표계의 각 단위벡터를 $x_1, x_2, \cdots , x_M$이라고 합시다. 그러면 $0 \le \Sigma x \le N$인 모든 정수점의 집합이 하이퍼 삼각형 조각의 가능한 모든 꼭짓점의 집합이 됩니다.
모든 하이퍼 삼각형 조각의 각 면은 $x_i = c$ 또는 $\Sigma x = c$ 에 있습니다. 각 c의 값은 한 변의 길이와 $M$개의 변이 서로 수직하는 꼭짓점의 좌표에 의해 결정됩니다.
따라서 하나의 꼭짓점을 기준으로, 각 변은 모두 양의 방향으로 뻗거나 모두 음의 방향으로 뻗습니다. 이를 이용해 각 꼭짓점마다 양의 방향과 음의 방향으로 뻗을 수 있는 거리를 합하는 것으로 문제를 변형할 수 있습니다.
음의 방향으로 뻗는 하이퍼 삼각형의 개수
꼭짓점의 좌표를 $(x_1, x_2, \cdots, x_M)$으로 두면, 음의 방향으로의 하이퍼 삼각형의 개수는 $\min(x_1, x_2, \cdots, x_M)$ 입니다. 즉 한 변의 길이가 $K$인 음의 방향으로의 하이퍼 삼각형을 만들기 위해서는 꼭짓점의 모든 좌표가 $K$ 이상이여야 합니다.
그러니 각 $K$마다 모든 좌표가 $K$ 이상인 점의 개수를 셉시다. 각 좌표마다 $K$씩 일단 값을 주고, 나머지 $N-MK$를 나눠주는 방법의 수를 세면 됩니다. 물론 값을 나눠주지 않거나 일부만 나눠줘도 되기 때문에, 값을 버리는 용도의 좌표를 하나 만들어서 생각합시다.
이 값을 중복조합을 이용해 표현하면 $_{M+1} H _{N-MK}$ 가 됩니다.
양의 방향으로 뻗는 하이퍼 삼각형의 개수
꼭짓점의 좌표를 $(x_1, x_2, \cdots, x_M)$로 두면, 양의 방향으로의 하이퍼 삼각형의 개수는 $N-\Sigma x$ 입니다.
그러니 이번에는 좌표의 합이 $S$ 인 점의 개수를 세고 거기에 $N-S$를 곱합시다. $S$를 $M$개의 좌표에 모두 나눠주는 방법의 수를 세면 됩니다. 점의 개수를 중복조합으로 표현하면 $_{M} H _{S}$ 가 됩니다.
중복조합 빠르게 계산하기
${n}H _{r} =\ _{n+r-1}C{r} = {(n+r-1)! \over r!(n-1)!}$ 입니다. 팩토리얼 값을 미리 전부 구해두면 빠르게 문제를 해결할 수 있습니다.
모듈러 위에서의 나눗셈은 모듈러가 소수이므로 모듈러 곱셈 역원을 쉽게 구할 수 있습니다.
역원까지 전처리한다면 총 시간 복잡도 $\mathcal{O}(N+M)$에 문제를 해결할 수 있습니다.
코드
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
const ll mod = 1e9+7;
ll fac[234567], inv[234567];
inline ll H(int n, int r){return fac[n+r-1]*inv[r]%mod*inv[n-1]%mod;}
ll pw(ll n, ll r){
ll ret = 1;
while(r){
if(r&1) ret = ret*n%mod;
n = n*n%mod;
r >>= 1;
}
return ret;
}
int main(){
cin.tie(0);cout.tie(0);ios::sync_with_stdio(0);
int m, n; cin >> m >> n;
fac[0] = 1;
for(int i = 1; i <= 202020; i++) fac[i] = fac[i-1]*i%mod;
inv[202020] = pw(fac[202020], mod-2);
for(int i = 202020; i; i--) inv[i-1] = inv[i]*i%mod;
ll ans = 0;
for(int s = 0; s < n; s++){
ans += H(m,s)*(n-s)%mod;
ans %= mod;
}
for(int k = 1; k*m <= n; k++){
ans += H(m+1,n-m*k);
ans %= mod;
}
cout << ans;
}