ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 백준 18779 풀이 (Help Yourself)
    문제풀이 2021. 6. 15. 14:02

    문제 링크: https://www.acmicpc.net/problem/18779

     

    18779번: Help Yourself (Platinum)

    Bessie has been given $N$ ($1\le N\le 10^5$) segments on a 1D number line. The $i$th segment contains all reals $x$ such that $l_i\le x\le r_i$. Define the union of a set of segments to be the set of all $x$ that are contained within at least one segment.

    www.acmicpc.net

     

    문제 설명: [1, 2N] 구간에 선분이 N개 있을 때, 2^N개의 모든 선분의 부분집합에 대해 (연결성분의 개수)^K의 합을 MOD로 나눈 나머지를 출력하는 문제입니다.

     

    풀이)

    일단 K=1인 경우부터 해결해봅시다. K=1인 경우는 다음 문제와 동일합니다.

    https://www.acmicpc.net/problem/18781

     

    18781번: Help Yourself (Gold)

    Bessie has been given $N$ segments ($1\le N\le 10^5$) on a 1D number line. The $i$th segment contains all reals $x$ such that $l_i\le x\le r_i$. Define the union of a set of segments to be the set of all $x$ that are contained within at least one segment.

    www.acmicpc.net

    저는 이 문제를 2가지 풀이로 풀었는데 모두 설명하도록 하겠습니다. 이 중 2번째 풀이를 확장해서 원래 문제를 풀 것입니다.

     

    풀이 1)

    K=1이므로 각 연결성분의 개수를 세주면 됩니다. [1, 2N] 구간에 있는 점 i에 대해, i에서 끝나는 연결성분을 가지는 부분집합의 개수의 합을 구하면 답이 됨을 알 수 있습니다. 이때, i를 포함하는 선분은 부분집합의 원소가 되어선 안 되고, 그렇지 않은 선분은 따로 고려할 필요가 없기 때문에, 각 선분의 끝점마다 2^{N-(i를 포함하는 선분 개수)}를 전부 더해주면 됩니다.

    선분의 시작점을 +1, 끝점을 -1로 놓고, 현재까지의 누적합을 계산하면 현재 있는 점이 선분의 끝점일 경우, 이 값은 i를 포함하는 선분 개수와 같아짐을 알 수 있습니다. 따라서, 선분의 끝점에 도달할 때마다 2^{N-(누적합)}을 답에 더해주면 문제가 풀립니다.

     

    시간복잡도는 O(N)입니다.

     

    구현:

    더보기
    #include <bits/stdc++.h>
    
    typedef long long ll;
    using namespace std;
    const int MOD = 1e9+7;
    int a[200200], pw[200200];
    
    int main(){
        int n;
        scanf("%d", &n);
        for (int i=0;i<n;i++){
            int x, y;
            scanf("%d %d", &x, &y);
            a[x]++; a[y]--;
        }
        pw[0] = 1;
        for (int i=1;i<=n;i++) pw[i] = (pw[i-1]<<1)%MOD;
        int cur = 0, ans = 0;
        for (int i=1;i<=(n<<1);i++){
            if (a[i]==-1){
                ans += pw[n-cur];
                if (ans>=MOD) ans -= MOD;
            }
            cur += a[i];
        }
        printf("%d\n", ans);
        return 0;
    }
    

     

    풀이 2)

    다이나믹 프로그래밍의 관점에서 접근해봅시다. 각 선분들을 시작점 기준으로 정렬해놓은 상태라고 가정합시다.

    i번째 선분의 시작지점과 끝점을 각각 s_i, e_i라고 합시다.

    그리고 dp[i] = (1번째 선분부터 i번째 선분까지만 존재할 때 답) 이라고 정의합시다.

     

    그러면 다음 점화식이 성립합니다.

    dp[i] = dp[i-1]*2 + (합집합의 원소 중 최댓값이 s_i보다 작은 부분집합의 개수)

     

    기존에 i-1번째 선분까지만 있을 때의 답은 dp[i-1]이고, 여기에 i번째 선분을 추가했을 때 연결성분의 개수는 0 또는 1 증가합니다. 1이 증가하는 경우는 선택한 부분집합의 원소들이 나타내는 영역의 원소 중 최댓값이 s_i보다 작은 경우뿐이므로, 위 점화식이 성립함을 알 수 있습니다.

     

    위 점화식을 계산하기 위해 세그먼트 트리를 하나 만듭시다. 세그먼트 트리의 i번째 리프에는 합집합의 원소 중 최댓값이 정확히 i인 부분집합의 개수를 저장하고, 각 노드에는 자식노드의 합을 저장합시다. 그러면 점화식을 O(logN)에 계산할 수 있고, 각 단계마다 세그먼트 트리를 업데이트해주면 됩니다.

     

    [0, e_i]구간에서 끝나는 부분집합에 i번째 선분을 추가할 경우 끝지점이 e_i가 되고, [e_i+1, 2N]구간에서 끝나는 부분집합에 i번째 선분을 추가하면 끝지점이 변하지 않습니다.

     

    따라서, 각 단계마다 e_i번째 리프의 값에 [0, e_i] 합+1을 더해주고(공집합에 추가하는 경우 때문에 +1), [e_i+1, 2N]번째 리프의 값을 전부 2배해주면 세그먼트 트리를 갱신할 수 있습니다.

    lazy propagation을 사용하면 O(NlogN)에 구현할 수 있습니다.

     

    구현:

    더보기
    #include <bits/stdc++.h>
    
    using namespace std;
    typedef long long ll;
    pair<int, int> xy[100100];
    ll tree[800400], lazy[800400], pw[100100], MOD=1e9+7;
    
    void propagate(int i, int l, int r){
        tree[i]=(tree[i]*pw[lazy[i]])%MOD;
        if (l!=r){
            lazy[i<<1] += lazy[i];
            lazy[i<<1|1] += lazy[i];
        }
        lazy[i]=0;
    }
    
    void update1(int i, int l, int r, int s, int e){
        propagate(i, l, r);
        if (r<s || e<l) return;
        if (s<=l && r<=e){
            lazy[i]++; propagate(i, l, r);
            return;
        }
        int m=(l+r)>>1;
        update1(i<<1, l, m, s, e); update1(i<<1|1, m+1, r, s, e);
        tree[i]=(tree[i<<1]+tree[i<<1|1])%MOD;
    }
    
    void update2(int i, int l, int r, int pos, int val){
        //printf("upd2: %d %d %d %d %d\n", i, l, r, pos, val);
        propagate(i, l, r);
        if (pos<l || r<pos) return;
        if (l==r){
            tree[i] = (tree[i]+val)%MOD;
            return;
        }
        int m=(l+r)>>1;
        update2(i<<1, l, m, pos, val); update2(i<<1|1, m+1, r, pos, val);
        tree[i]=(tree[i<<1]+tree[i<<1|1])%MOD;
    }
    
    ll query(int i, int l, int r, int s, int e){
        propagate(i, l, r);
        if (r<s || e<l) return 0;
        if (s<=l && r<=e) return tree[i];
        return (query(i<<1, l, (l+r)>>1, s, e)+query(i<<1|1, ((l+r)>>1)+1, r, s, e))%MOD;
    }
    
    int main(){
        pw[0]=1;
        for (int i=1;i<100100;i++){
            pw[i]=(pw[i-1]<<1)%MOD;
        }
        int n;
        scanf("%d", &n);
        for (int i=0;i<n;i++) scanf("%d %d", &xy[i].first, &xy[i].second);
        sort(xy, xy+n);
        ll ans=0;
        for (int i=0;i<n;i++){
            auto p = xy[i];
            ans = ((ans<<1) + query(1, 0, (n<<1)+1, 0, p.first-1)+1)%MOD;
            update1(1, 0, (n<<1)+1, p.second+1, (n<<1)+1);
            update2(1, 0, (n<<1)+1, p.second, query(1, 0, (n<<1)+1, 0, p.second)+1);
            //for (int j=0;j<=(n<<1)+1;j++) printf("%lld ", query(1, 0, (n<<1)+1, j, j));
        }
        printf("%lld\n", ans);
        return 0;
    }
    

     

    이제 원래 문제인 2<=K<=10인 경우를 풀어봅시다.

    2번째 풀이와 동일하게 dp[i]를 정의하고, 점화식을 세워봅시다.

    연결성분의 개수가 x에서 x+1이 될 때, 증가하는 값은 (x+1)^k-x^k = \Sum_{i=0}^{k-1} {kCi * x^i} 임을 알 수 있습니다. 이 값을 계산하기 위해 세그먼트 트리를 k개 만들고, i번째 세그먼트 트리에는 합집합의 원소 중 최댓값이 정확히 i인 부분집합의 (연결성분 개수)^i 값의 합을 저장합시다.

     

    그러면 dp[i] = dp[i-1]*2 + \Sum_{j=0}^{k-1}{kCj * query_j(0, s_i-1)} 가 됨을 알 수 있습니다. 

     

    또한, 세그먼트 트리 갱신도 풀이 2와 비슷하게 처리할 수 있습니다. t번째 세그먼트 트리를 갱신하려고 하면 다음과 같이 해주면 됩니다.

    끝지점이 [e_i+1, 2N]인 경우, i번째 선분을 추가했을 때 끝지점이 변하지 않으므로 [e_i+1, 2N]의 값을 전부 2배해줍니다.

    끝지점이 [s_i, e_i]인 경우, i번째 선분을 추가했을 때 끝지점이 e_i가 되고, 연결성분의 개수가 변하지 않으므로, e_i번째 리프에 query_t(s_i, e_i)+1을 더해줍니다.

    끝지점이 [0, s_i-1]인 경우, i번째 선분을 추가했을 때 끝지점이 e_i가 되고, 연결성분의 개수가 1 증가하므로, e_i번째 리프에 \Sum_{j=0}^{t}{tCj * query_j(0, s_i-1)}을 더해줍니다.

     

    시간복잡도는 O(NK(logN+K))입니다.

     

    구현:

    더보기
    #include <bits/stdc++.h>
    #pragma GCC optimize("Ofast")
    #pragma GCC target("avx,avx2,fma")
    
    typedef long long ll;
    using namespace std;
    const int MOD = 1e9+7;
    pair<int, int> xy[100100];
    int pw[100100], ncr[101][101];
    struct Seg{
        int tree[800800], lazy[800800];
        void propagate(int i, int l, int r){
            tree[i] = (ll)tree[i]*pw[lazy[i]]%MOD;
            if (l!=r) lazy[i<<1] += lazy[i], lazy[i<<1|1] += lazy[i];
            lazy[i] = 0;
        }
        void update1(int i, int l, int r, int s, int e){
            propagate(i, l, r);
            if (r<s || e<l) return;
            if (s<=l && r<=e){
                lazy[i]++; propagate(i, l, r);
                return;
            }
            int m = (l+r)>>1;
            update1(i<<1, l, m, s, e); update1(i<<1|1, m+1, r, s, e);
            tree[i] = tree[i<<1]+tree[i<<1|1];
            if (tree[i]>=MOD) tree[i] -= MOD;
        }
        void update2(int i, int l, int r, int p, int val){
            propagate(i, l, r);
            if (p<l || r<p) return;
            if (l==r){
                tree[i] += val;
                if (tree[i]>=MOD) tree[i] -= MOD;
                return;
            }
            int m = (l+r)>>1;
            update2(i<<1, l, m, p, val); update2(i<<1|1, m+1, r, p, val);
            tree[i] = tree[i<<1]+tree[i<<1|1];
            if (tree[i]>=MOD) tree[i] -= MOD;
        }
        int query(int i, int l, int r, int s, int e){
            propagate(i, l, r);
            if (r<s || e<l) return 0;
            if (s<=l && r<=e) return tree[i];
            int m = (l+r)>>1;
            int tmp = query(i<<1, l, m, s, e)+query(i<<1|1, m+1, r, s, e);
            if (tmp>=MOD) return tmp-MOD;
            return tmp;
        }
    }tree[11];
    int val[11];
    
    int main(){
        cin.tie(NULL);
        ios_base::sync_with_stdio(false);
        int n, k;
        cin >> n >> k;
        pw[0] = 1;
        for (int i=1;i<100100;i++){
            pw[i] = pw[i-1]<<1;
            if (pw[i]>=MOD) pw[i] -= MOD;
        }
        ncr[0][0] = 1;
        for (int i=1;i<=k;i++){
            ncr[i][i] = ncr[i][0] = 1;
            for (int j=1;j<i;j++){
                ncr[i][j] = ncr[i-1][j-1]+ncr[i-1][j];
            }
        }
        for (int i=0;i<n;i++) cin >> xy[i].first >> xy[i].second;
        sort(xy, xy+n);
        int MX = (n<<1)+1;
        ll ans = 0;
        for (int i=0;i<n;i++){
            for (int j=0;j<k;j++) val[j] = tree[j].query(1, 0, MX, 0, xy[i].first-1);
            ans <<= 1;
            for (int j=0;j<k;j++) ans += (ll)val[j]*ncr[k][j];
            ans++;
            ans %= MOD;
            //printf("%d %lld\n", val[0], ans);
    
            for (int j=0;j<k;j++){
                tree[j].update1(1, 0, MX, xy[i].second+1, MX);
                ll tmp = tree[j].query(1, 0, MX, xy[i].first, xy[i].second)+1;
                for (int l=0;l<=j;l++) tmp += (ll)val[l]*ncr[j][l];
                tmp %= MOD;
                tree[j].update2(1, 0, MX, xy[i].second, (int)tmp);
            }
        }
        cout << ans;
        return 0;
    }
    

     

    풀이에서 이해가 안 되는 부분이 있다면 댓글 달아주세요.

    댓글

Designed by Tistory.