-
백준 18779 풀이 (Help Yourself)문제풀이 2021. 6. 15. 14:02
문제 링크: https://www.acmicpc.net/problem/18779
문제 설명: [1, 2N] 구간에 선분이 N개 있을 때, 2^N개의 모든 선분의 부분집합에 대해 (연결성분의 개수)^K의 합을 MOD로 나눈 나머지를 출력하는 문제입니다.
풀이)
일단 K=1인 경우부터 해결해봅시다. K=1인 경우는 다음 문제와 동일합니다.
https://www.acmicpc.net/problem/18781
저는 이 문제를 2가지 풀이로 풀었는데 모두 설명하도록 하겠습니다. 이 중 2번째 풀이를 확장해서 원래 문제를 풀 것입니다.
풀이 1)
K=1이므로 각 연결성분의 개수를 세주면 됩니다. [1, 2N] 구간에 있는 점 i에 대해, i에서 끝나는 연결성분을 가지는 부분집합의 개수의 합을 구하면 답이 됨을 알 수 있습니다. 이때, i를 포함하는 선분은 부분집합의 원소가 되어선 안 되고, 그렇지 않은 선분은 따로 고려할 필요가 없기 때문에, 각 선분의 끝점마다 2^{N-(i를 포함하는 선분 개수)}를 전부 더해주면 됩니다.
선분의 시작점을 +1, 끝점을 -1로 놓고, 현재까지의 누적합을 계산하면 현재 있는 점이 선분의 끝점일 경우, 이 값은 i를 포함하는 선분 개수와 같아짐을 알 수 있습니다. 따라서, 선분의 끝점에 도달할 때마다 2^{N-(누적합)}을 답에 더해주면 문제가 풀립니다.
시간복잡도는 O(N)입니다.
구현:
더보기#include <bits/stdc++.h> typedef long long ll; using namespace std; const int MOD = 1e9+7; int a[200200], pw[200200]; int main(){ int n; scanf("%d", &n); for (int i=0;i<n;i++){ int x, y; scanf("%d %d", &x, &y); a[x]++; a[y]--; } pw[0] = 1; for (int i=1;i<=n;i++) pw[i] = (pw[i-1]<<1)%MOD; int cur = 0, ans = 0; for (int i=1;i<=(n<<1);i++){ if (a[i]==-1){ ans += pw[n-cur]; if (ans>=MOD) ans -= MOD; } cur += a[i]; } printf("%d\n", ans); return 0; }
풀이 2)
다이나믹 프로그래밍의 관점에서 접근해봅시다. 각 선분들을 시작점 기준으로 정렬해놓은 상태라고 가정합시다.
i번째 선분의 시작지점과 끝점을 각각 s_i, e_i라고 합시다.
그리고 dp[i] = (1번째 선분부터 i번째 선분까지만 존재할 때 답) 이라고 정의합시다.
그러면 다음 점화식이 성립합니다.
dp[i] = dp[i-1]*2 + (합집합의 원소 중 최댓값이 s_i보다 작은 부분집합의 개수)
기존에 i-1번째 선분까지만 있을 때의 답은 dp[i-1]이고, 여기에 i번째 선분을 추가했을 때 연결성분의 개수는 0 또는 1 증가합니다. 1이 증가하는 경우는 선택한 부분집합의 원소들이 나타내는 영역의 원소 중 최댓값이 s_i보다 작은 경우뿐이므로, 위 점화식이 성립함을 알 수 있습니다.
위 점화식을 계산하기 위해 세그먼트 트리를 하나 만듭시다. 세그먼트 트리의 i번째 리프에는 합집합의 원소 중 최댓값이 정확히 i인 부분집합의 개수를 저장하고, 각 노드에는 자식노드의 합을 저장합시다. 그러면 점화식을 O(logN)에 계산할 수 있고, 각 단계마다 세그먼트 트리를 업데이트해주면 됩니다.
[0, e_i]구간에서 끝나는 부분집합에 i번째 선분을 추가할 경우 끝지점이 e_i가 되고, [e_i+1, 2N]구간에서 끝나는 부분집합에 i번째 선분을 추가하면 끝지점이 변하지 않습니다.
따라서, 각 단계마다 e_i번째 리프의 값에 [0, e_i] 합+1을 더해주고(공집합에 추가하는 경우 때문에 +1), [e_i+1, 2N]번째 리프의 값을 전부 2배해주면 세그먼트 트리를 갱신할 수 있습니다.
lazy propagation을 사용하면 O(NlogN)에 구현할 수 있습니다.
구현:
더보기#include <bits/stdc++.h> using namespace std; typedef long long ll; pair<int, int> xy[100100]; ll tree[800400], lazy[800400], pw[100100], MOD=1e9+7; void propagate(int i, int l, int r){ tree[i]=(tree[i]*pw[lazy[i]])%MOD; if (l!=r){ lazy[i<<1] += lazy[i]; lazy[i<<1|1] += lazy[i]; } lazy[i]=0; } void update1(int i, int l, int r, int s, int e){ propagate(i, l, r); if (r<s || e<l) return; if (s<=l && r<=e){ lazy[i]++; propagate(i, l, r); return; } int m=(l+r)>>1; update1(i<<1, l, m, s, e); update1(i<<1|1, m+1, r, s, e); tree[i]=(tree[i<<1]+tree[i<<1|1])%MOD; } void update2(int i, int l, int r, int pos, int val){ //printf("upd2: %d %d %d %d %d\n", i, l, r, pos, val); propagate(i, l, r); if (pos<l || r<pos) return; if (l==r){ tree[i] = (tree[i]+val)%MOD; return; } int m=(l+r)>>1; update2(i<<1, l, m, pos, val); update2(i<<1|1, m+1, r, pos, val); tree[i]=(tree[i<<1]+tree[i<<1|1])%MOD; } ll query(int i, int l, int r, int s, int e){ propagate(i, l, r); if (r<s || e<l) return 0; if (s<=l && r<=e) return tree[i]; return (query(i<<1, l, (l+r)>>1, s, e)+query(i<<1|1, ((l+r)>>1)+1, r, s, e))%MOD; } int main(){ pw[0]=1; for (int i=1;i<100100;i++){ pw[i]=(pw[i-1]<<1)%MOD; } int n; scanf("%d", &n); for (int i=0;i<n;i++) scanf("%d %d", &xy[i].first, &xy[i].second); sort(xy, xy+n); ll ans=0; for (int i=0;i<n;i++){ auto p = xy[i]; ans = ((ans<<1) + query(1, 0, (n<<1)+1, 0, p.first-1)+1)%MOD; update1(1, 0, (n<<1)+1, p.second+1, (n<<1)+1); update2(1, 0, (n<<1)+1, p.second, query(1, 0, (n<<1)+1, 0, p.second)+1); //for (int j=0;j<=(n<<1)+1;j++) printf("%lld ", query(1, 0, (n<<1)+1, j, j)); } printf("%lld\n", ans); return 0; }
이제 원래 문제인 2<=K<=10인 경우를 풀어봅시다.
2번째 풀이와 동일하게 dp[i]를 정의하고, 점화식을 세워봅시다.
연결성분의 개수가 x에서 x+1이 될 때, 증가하는 값은 (x+1)^k-x^k = \Sum_{i=0}^{k-1} {kCi * x^i} 임을 알 수 있습니다. 이 값을 계산하기 위해 세그먼트 트리를 k개 만들고, i번째 세그먼트 트리에는 합집합의 원소 중 최댓값이 정확히 i인 부분집합의 (연결성분 개수)^i 값의 합을 저장합시다.
그러면 dp[i] = dp[i-1]*2 + \Sum_{j=0}^{k-1}{kCj * query_j(0, s_i-1)} 가 됨을 알 수 있습니다.
또한, 세그먼트 트리 갱신도 풀이 2와 비슷하게 처리할 수 있습니다. t번째 세그먼트 트리를 갱신하려고 하면 다음과 같이 해주면 됩니다.
끝지점이 [e_i+1, 2N]인 경우, i번째 선분을 추가했을 때 끝지점이 변하지 않으므로 [e_i+1, 2N]의 값을 전부 2배해줍니다.
끝지점이 [s_i, e_i]인 경우, i번째 선분을 추가했을 때 끝지점이 e_i가 되고, 연결성분의 개수가 변하지 않으므로, e_i번째 리프에 query_t(s_i, e_i)+1을 더해줍니다.
끝지점이 [0, s_i-1]인 경우, i번째 선분을 추가했을 때 끝지점이 e_i가 되고, 연결성분의 개수가 1 증가하므로, e_i번째 리프에 \Sum_{j=0}^{t}{tCj * query_j(0, s_i-1)}을 더해줍니다.
시간복잡도는 O(NK(logN+K))입니다.
구현:
더보기#include <bits/stdc++.h> #pragma GCC optimize("Ofast") #pragma GCC target("avx,avx2,fma") typedef long long ll; using namespace std; const int MOD = 1e9+7; pair<int, int> xy[100100]; int pw[100100], ncr[101][101]; struct Seg{ int tree[800800], lazy[800800]; void propagate(int i, int l, int r){ tree[i] = (ll)tree[i]*pw[lazy[i]]%MOD; if (l!=r) lazy[i<<1] += lazy[i], lazy[i<<1|1] += lazy[i]; lazy[i] = 0; } void update1(int i, int l, int r, int s, int e){ propagate(i, l, r); if (r<s || e<l) return; if (s<=l && r<=e){ lazy[i]++; propagate(i, l, r); return; } int m = (l+r)>>1; update1(i<<1, l, m, s, e); update1(i<<1|1, m+1, r, s, e); tree[i] = tree[i<<1]+tree[i<<1|1]; if (tree[i]>=MOD) tree[i] -= MOD; } void update2(int i, int l, int r, int p, int val){ propagate(i, l, r); if (p<l || r<p) return; if (l==r){ tree[i] += val; if (tree[i]>=MOD) tree[i] -= MOD; return; } int m = (l+r)>>1; update2(i<<1, l, m, p, val); update2(i<<1|1, m+1, r, p, val); tree[i] = tree[i<<1]+tree[i<<1|1]; if (tree[i]>=MOD) tree[i] -= MOD; } int query(int i, int l, int r, int s, int e){ propagate(i, l, r); if (r<s || e<l) return 0; if (s<=l && r<=e) return tree[i]; int m = (l+r)>>1; int tmp = query(i<<1, l, m, s, e)+query(i<<1|1, m+1, r, s, e); if (tmp>=MOD) return tmp-MOD; return tmp; } }tree[11]; int val[11]; int main(){ cin.tie(NULL); ios_base::sync_with_stdio(false); int n, k; cin >> n >> k; pw[0] = 1; for (int i=1;i<100100;i++){ pw[i] = pw[i-1]<<1; if (pw[i]>=MOD) pw[i] -= MOD; } ncr[0][0] = 1; for (int i=1;i<=k;i++){ ncr[i][i] = ncr[i][0] = 1; for (int j=1;j<i;j++){ ncr[i][j] = ncr[i-1][j-1]+ncr[i-1][j]; } } for (int i=0;i<n;i++) cin >> xy[i].first >> xy[i].second; sort(xy, xy+n); int MX = (n<<1)+1; ll ans = 0; for (int i=0;i<n;i++){ for (int j=0;j<k;j++) val[j] = tree[j].query(1, 0, MX, 0, xy[i].first-1); ans <<= 1; for (int j=0;j<k;j++) ans += (ll)val[j]*ncr[k][j]; ans++; ans %= MOD; //printf("%d %lld\n", val[0], ans); for (int j=0;j<k;j++){ tree[j].update1(1, 0, MX, xy[i].second+1, MX); ll tmp = tree[j].query(1, 0, MX, xy[i].first, xy[i].second)+1; for (int l=0;l<=j;l++) tmp += (ll)val[l]*ncr[j][l]; tmp %= MOD; tree[j].update2(1, 0, MX, xy[i].second, (int)tmp); } } cout << ans; return 0; }
풀이에서 이해가 안 되는 부분이 있다면 댓글 달아주세요.
'문제풀이' 카테고리의 다른 글
IOI21 Keys 풀이 (0) 2022.06.17 NYPC 2021 본선 1519부문 3번 풀이 (BOJ 23347) (1) 2021.11.01 백준 9208번 풀이 (링월드) (0) 2021.05.28 백준 8481번 풀이 (Generator) (1) 2021.04.29 백준 18799번 풀이 (이상한 편집기) (3) 2021.03.26