-
백준 18779 풀이 (Help Yourself)문제풀이 2021. 6. 15. 14:02
문제 링크: https://www.acmicpc.net/problem/18779
문제 설명: [1, 2N] 구간에 선분이 N개 있을 때, 2^N개의 모든 선분의 부분집합에 대해 (연결성분의 개수)^K의 합을 MOD로 나눈 나머지를 출력하는 문제입니다.
풀이)
일단 K=1인 경우부터 해결해봅시다. K=1인 경우는 다음 문제와 동일합니다.
https://www.acmicpc.net/problem/18781
저는 이 문제를 2가지 풀이로 풀었는데 모두 설명하도록 하겠습니다. 이 중 2번째 풀이를 확장해서 원래 문제를 풀 것입니다.
풀이 1)
K=1이므로 각 연결성분의 개수를 세주면 됩니다. [1, 2N] 구간에 있는 점 i에 대해, i에서 끝나는 연결성분을 가지는 부분집합의 개수의 합을 구하면 답이 됨을 알 수 있습니다. 이때, i를 포함하는 선분은 부분집합의 원소가 되어선 안 되고, 그렇지 않은 선분은 따로 고려할 필요가 없기 때문에, 각 선분의 끝점마다 2^{N-(i를 포함하는 선분 개수)}를 전부 더해주면 됩니다.
선분의 시작점을 +1, 끝점을 -1로 놓고, 현재까지의 누적합을 계산하면 현재 있는 점이 선분의 끝점일 경우, 이 값은 i를 포함하는 선분 개수와 같아짐을 알 수 있습니다. 따라서, 선분의 끝점에 도달할 때마다 2^{N-(누적합)}을 답에 더해주면 문제가 풀립니다.
시간복잡도는 O(N)입니다.
구현:
더보기#include <bits/stdc++.h> typedef long long ll; using namespace std; const int MOD = 1e9+7; int a[200200], pw[200200]; int main(){ int n; scanf("%d", &n); for (int i=0;i<n;i++){ int x, y; scanf("%d %d", &x, &y); a[x]++; a[y]--; } pw[0] = 1; for (int i=1;i<=n;i++) pw[i] = (pw[i-1]<<1)%MOD; int cur = 0, ans = 0; for (int i=1;i<=(n<<1);i++){ if (a[i]==-1){ ans += pw[n-cur]; if (ans>=MOD) ans -= MOD; } cur += a[i]; } printf("%d\n", ans); return 0; }
풀이 2)
다이나믹 프로그래밍의 관점에서 접근해봅시다. 각 선분들을 시작점 기준으로 정렬해놓은 상태라고 가정합시다.
i번째 선분의 시작지점과 끝점을 각각 s_i, e_i라고 합시다.
그리고 dp[i] = (1번째 선분부터 i번째 선분까지만 존재할 때 답) 이라고 정의합시다.
그러면 다음 점화식이 성립합니다.
dp[i] = dp[i-1]*2 + (합집합의 원소 중 최댓값이 s_i보다 작은 부분집합의 개수)
기존에 i-1번째 선분까지만 있을 때의 답은 dp[i-1]이고, 여기에 i번째 선분을 추가했을 때 연결성분의 개수는 0 또는 1 증가합니다. 1이 증가하는 경우는 선택한 부분집합의 원소들이 나타내는 영역의 원소 중 최댓값이 s_i보다 작은 경우뿐이므로, 위 점화식이 성립함을 알 수 있습니다.
위 점화식을 계산하기 위해 세그먼트 트리를 하나 만듭시다. 세그먼트 트리의 i번째 리프에는 합집합의 원소 중 최댓값이 정확히 i인 부분집합의 개수를 저장하고, 각 노드에는 자식노드의 합을 저장합시다. 그러면 점화식을 O(logN)에 계산할 수 있고, 각 단계마다 세그먼트 트리를 업데이트해주면 됩니다.
[0, e_i]구간에서 끝나는 부분집합에 i번째 선분을 추가할 경우 끝지점이 e_i가 되고, [e_i+1, 2N]구간에서 끝나는 부분집합에 i번째 선분을 추가하면 끝지점이 변하지 않습니다.
따라서, 각 단계마다 e_i번째 리프의 값에 [0, e_i] 합+1을 더해주고(공집합에 추가하는 경우 때문에 +1), [e_i+1, 2N]번째 리프의 값을 전부 2배해주면 세그먼트 트리를 갱신할 수 있습니다.
lazy propagation을 사용하면 O(NlogN)에 구현할 수 있습니다.
구현:
더보기#include <bits/stdc++.h> using namespace std; typedef long long ll; pair<int, int> xy[100100]; ll tree[800400], lazy[800400], pw[100100], MOD=1e9+7; void propagate(int i, int l, int r){ tree[i]=(tree[i]*pw[lazy[i]])%MOD; if (l!=r){ lazy[i<<1] += lazy[i]; lazy[i<<1|1] += lazy[i]; } lazy[i]=0; } void update1(int i, int l, int r, int s, int e){ propagate(i, l, r); if (r<s || e<l) return; if (s<=l && r<=e){ lazy[i]++; propagate(i, l, r); return; } int m=(l+r)>>1; update1(i<<1, l, m, s, e); update1(i<<1|1, m+1, r, s, e); tree[i]=(tree[i<<1]+tree[i<<1|1])%MOD; } void update2(int i, int l, int r, int pos, int val){ //printf("upd2: %d %d %d %d %d\n", i, l, r, pos, val); propagate(i, l, r); if (pos<l || r<pos) return; if (l==r){ tree[i] = (tree[i]+val)%MOD; return; } int m=(l+r)>>1; update2(i<<1, l, m, pos, val); update2(i<<1|1, m+1, r, pos, val); tree[i]=(tree[i<<1]+tree[i<<1|1])%MOD; } ll query(int i, int l, int r, int s, int e){ propagate(i, l, r); if (r<s || e<l) return 0; if (s<=l && r<=e) return tree[i]; return (query(i<<1, l, (l+r)>>1, s, e)+query(i<<1|1, ((l+r)>>1)+1, r, s, e))%MOD; } int main(){ pw[0]=1; for (int i=1;i<100100;i++){ pw[i]=(pw[i-1]<<1)%MOD; } int n; scanf("%d", &n); for (int i=0;i<n;i++) scanf("%d %d", &xy[i].first, &xy[i].second); sort(xy, xy+n); ll ans=0; for (int i=0;i<n;i++){ auto p = xy[i]; ans = ((ans<<1) + query(1, 0, (n<<1)+1, 0, p.first-1)+1)%MOD; update1(1, 0, (n<<1)+1, p.second+1, (n<<1)+1); update2(1, 0, (n<<1)+1, p.second, query(1, 0, (n<<1)+1, 0, p.second)+1); //for (int j=0;j<=(n<<1)+1;j++) printf("%lld ", query(1, 0, (n<<1)+1, j, j)); } printf("%lld\n", ans); return 0; }
이제 원래 문제인 2<=K<=10인 경우를 풀어봅시다.
2번째 풀이와 동일하게 dp[i]를 정의하고, 점화식을 세워봅시다.
연결성분의 개수가 x에서 x+1이 될 때, 증가하는 값은 (x+1)^k-x^k = \Sum_{i=0}^{k-1} {kCi * x^i} 임을 알 수 있습니다. 이 값을 계산하기 위해 세그먼트 트리를 k개 만들고, i번째 세그먼트 트리에는 합집합의 원소 중 최댓값이 정확히 i인 부분집합의 (연결성분 개수)^i 값의 합을 저장합시다.
그러면 dp[i] = dp[i-1]*2 + \Sum_{j=0}^{k-1}{kCj * query_j(0, s_i-1)} 가 됨을 알 수 있습니다.
또한, 세그먼트 트리 갱신도 풀이 2와 비슷하게 처리할 수 있습니다. t번째 세그먼트 트리를 갱신하려고 하면 다음과 같이 해주면 됩니다.
끝지점이 [e_i+1, 2N]인 경우, i번째 선분을 추가했을 때 끝지점이 변하지 않으므로 [e_i+1, 2N]의 값을 전부 2배해줍니다.
끝지점이 [s_i, e_i]인 경우, i번째 선분을 추가했을 때 끝지점이 e_i가 되고, 연결성분의 개수가 변하지 않으므로, e_i번째 리프에 query_t(s_i, e_i)+1을 더해줍니다.
끝지점이 [0, s_i-1]인 경우, i번째 선분을 추가했을 때 끝지점이 e_i가 되고, 연결성분의 개수가 1 증가하므로, e_i번째 리프에 \Sum_{j=0}^{t}{tCj * query_j(0, s_i-1)}을 더해줍니다.
시간복잡도는 O(NK(logN+K))입니다.
구현:
더보기#include <bits/stdc++.h> #pragma GCC optimize("Ofast") #pragma GCC target("avx,avx2,fma") typedef long long ll; using namespace std; const int MOD = 1e9+7; pair<int, int> xy[100100]; int pw[100100], ncr[101][101]; struct Seg{ int tree[800800], lazy[800800]; void propagate(int i, int l, int r){ tree[i] = (ll)tree[i]*pw[lazy[i]]%MOD; if (l!=r) lazy[i<<1] += lazy[i], lazy[i<<1|1] += lazy[i]; lazy[i] = 0; } void update1(int i, int l, int r, int s, int e){ propagate(i, l, r); if (r<s || e<l) return; if (s<=l && r<=e){ lazy[i]++; propagate(i, l, r); return; } int m = (l+r)>>1; update1(i<<1, l, m, s, e); update1(i<<1|1, m+1, r, s, e); tree[i] = tree[i<<1]+tree[i<<1|1]; if (tree[i]>=MOD) tree[i] -= MOD; } void update2(int i, int l, int r, int p, int val){ propagate(i, l, r); if (p<l || r<p) return; if (l==r){ tree[i] += val; if (tree[i]>=MOD) tree[i] -= MOD; return; } int m = (l+r)>>1; update2(i<<1, l, m, p, val); update2(i<<1|1, m+1, r, p, val); tree[i] = tree[i<<1]+tree[i<<1|1]; if (tree[i]>=MOD) tree[i] -= MOD; } int query(int i, int l, int r, int s, int e){ propagate(i, l, r); if (r<s || e<l) return 0; if (s<=l && r<=e) return tree[i]; int m = (l+r)>>1; int tmp = query(i<<1, l, m, s, e)+query(i<<1|1, m+1, r, s, e); if (tmp>=MOD) return tmp-MOD; return tmp; } }tree[11]; int val[11]; int main(){ cin.tie(NULL); ios_base::sync_with_stdio(false); int n, k; cin >> n >> k; pw[0] = 1; for (int i=1;i<100100;i++){ pw[i] = pw[i-1]<<1; if (pw[i]>=MOD) pw[i] -= MOD; } ncr[0][0] = 1; for (int i=1;i<=k;i++){ ncr[i][i] = ncr[i][0] = 1; for (int j=1;j<i;j++){ ncr[i][j] = ncr[i-1][j-1]+ncr[i-1][j]; } } for (int i=0;i<n;i++) cin >> xy[i].first >> xy[i].second; sort(xy, xy+n); int MX = (n<<1)+1; ll ans = 0; for (int i=0;i<n;i++){ for (int j=0;j<k;j++) val[j] = tree[j].query(1, 0, MX, 0, xy[i].first-1); ans <<= 1; for (int j=0;j<k;j++) ans += (ll)val[j]*ncr[k][j]; ans++; ans %= MOD; //printf("%d %lld\n", val[0], ans); for (int j=0;j<k;j++){ tree[j].update1(1, 0, MX, xy[i].second+1, MX); ll tmp = tree[j].query(1, 0, MX, xy[i].first, xy[i].second)+1; for (int l=0;l<=j;l++) tmp += (ll)val[l]*ncr[j][l]; tmp %= MOD; tree[j].update2(1, 0, MX, xy[i].second, (int)tmp); } } cout << ans; return 0; }
풀이에서 이해가 안 되는 부분이 있다면 댓글 달아주세요.
'문제풀이' 카테고리의 다른 글
IOI21 Keys 풀이 (0) 2022.06.17 NYPC 2021 본선 1519부문 3번 풀이 (BOJ 23347) (1) 2021.11.01 백준 9208번 풀이 (링월드) (0) 2021.05.28 백준 8481번 풀이 (Generator) (1) 2021.04.29 PS 일지 (4/7~4/19) (0) 2021.04.19