ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 백준 20535번 풀이 (Good bye, BOJ 2020! G번)
    문제풀이 2021. 1. 3. 20:45

    굿바이 boj 2020 대회에 참가하지 못해서 혼자 문제를 풀어보던 중 G번을 접하게 되었습니다. 문제를 보고 풀이가 금방 떠올랐는데 제 풀이가 알려져 있는 풀이들과는 다른 방식이여서 블로그에 올리게 되었습니다. (이 풀이는 정해보다 복잡한 방법이기 때문에 이렇게 푸는건 추천하지 않습니다. 그냥 이렇게도 풀 수 있구나 하고 재미로 봐주세요.)

     

    문제에 대해 간략하게 설명하자면, 쿼리마다 몇개의 점 v_1, v_2, ... , v_k가 들어오는데 가능한 모든 점의 순서쌍 (v_i, v_j)에 대해 lca의 깊이의 합을 계산해서 출력하는 문제입니다. (단, i<j)

     

    일단, 이 문제의 정해는 다음과 같이 3가지가 있습니다.

    1. 오프라인 쿼리

    2. 트리 압축

    3. O((N+X)sqrt(X))로 끼워맞춰서 통과하는 풀이 (X는 모든 쿼리의 K값들의 합입니다.)

    이 풀이들에 대한 자세한 설명은 생략하도록 하겠습니다.

     

    제가 여기서 설명해드릴 방법은 트리 압축 풀이와 되게 유사한 풀이입니다. 다른 점은 저는 트리 압축을 통해 쿼리의 답을 구하지 않고, 세그먼트 트리를 이용하여 쿼리를 처리했다는 점입니다. (시간복잡도 O((N+X)logN))

     

     

    풀이)

    일단 사전지식으로 O(1) LCA 알고리즘에 대해 알고 있어야 합니다. 이 방식에 대해서는 아래 링크를 참고하시기 바랍니다.

    www.secmem.org/blog/2019/03/27/fast-LCA-with-sparsetable/

     

    O(1) LCA Algorithm (with Sparse Table)

    목표 및 문제 소개 LCA(Lowest Common Ancestor)란 루트가 정해진 트리에서, 두 노드 간의 공통 조상이면서 루트에서 가장 먼 노드를 뜻합니다. 노드가 \(N\)개인 트리에서 임의의 두 노드 간의 LCA를 쿼리

    www.secmem.org

    이 알고리즘을 응용하여 쿼리를 처리하는 방법을 생각해볼 것입니다.

    일단, 이 알고리즘에서 핵심적인 아이디어는 어떤 점 u, v가 있을 때, 오일러 투어 상에서 점 u와 v 사이에 존재하는 모든 점들의 깊이 중 최솟값이 lca의 깊이가 된다는 아이디어입니다.

     

    더 정확한 설명을 위해 몇 가지 배열들을 정의해보도록 하겠습니다.

    euler[i]: 오일러 투어에서 i번째 시점에 방문하는 점 v를 저장한 배열

    idx[v]: 점 v가 오일러 투어에서 처음으로 나오는 인덱스를 저장한 배열

    depth[v]: 점 v의 깊이를 저장한 배열

     

    다시 설명하면, 어떤 점 u, v에 대해 u와 v의 lca의 깊이는 구간 [euler[idx[u]], euler[idx[v]]]에 존재하는 모든 점 w에 대해 depth[w]의 최솟값과 동일합니다.

     

    이 사실을 바탕으로 쿼리를 처리하는 방법에 대해 살펴봅시다.

    일단, 쿼리로 들어온 k개의 점들을 idx[v_i] 값 기준으로 오름차순 정렬합시다. 그리고, 인접한 점 v_i, v_{i+1}에 대해, 두 점의 lca를 lca[i]라고 정의합시다. 이렇게 했을 때, 위에서 살펴본 특징을 이용하면, 두 점 v_i와 v_j의 lca의 깊이는 min(lca[i], lca[i+1], ... , lca[j-1])와 동일합니다.

     

    이제, 쿼리를 처리해보도록 합시다.

     

    기본적인 접근 방식은 dp입니다. 현재 위치가 cur이고, cur이하인 모든 i, j에 대해 v_i와 v_j의 lca의 깊이를 ans에 더해준 상태라고 가정합시다. 또한, cur보다 작은 모든 i에 대해, v_i와 v_cur의 lca의 깊이에 대한 정보를 적당한 형태로 저장해놓았다고 가정합시다. 이제 cur에 1을 더한다음, ans와 lca의 깊이 정보를 갱신해봅시다.

     

    여기서 세그먼트 트리를 이용하어 ans와 lca의 깊이 정보들을 갱신할 것입니다.

    세그먼트 트리에 다음과 같은 값들을 저장해봅시다.

    구간 [l, r]에 대해

    1. lca(v_i, v_{cur-1})들 중 값이 l이상, r이하인 점 v_i의 개수

    2. lca(v_i, v_{cur-1})들 중 값이 l이상, r이하인 모든 점 v_i에 대한 lca(v_i, v_{cur-1})의 합

    (cur이 현재보다 1 작을 때 마지막으로 갱신되었기 때문에 현재 상태에서는 cur-1에 대해 저장되어있는 상태입니다.)

     

    (v_{cur-1}, v_{cur})의 값을 cur_lca라고 한다면, 위에서 얻은 특징들을 통해 ans에 다음과 같은 값을 더해주면 된다는 것을 알 수 있습니다.

    1. cur_lca*(구간 [cur_lca, INF]에 있는 점의 개수+1)

    2. 구간 [-1, cur_lca-1]에 있는 lca 값들의 합

    그 이유는 lca(v_i, v_cur)=min(lca(v_i, v_{cur-1}), cur_lca)이므로 lca(v_i, v_{cur-1}의 값이 cur_lca이상이면 lca의 깊이가 cur_lca가 되고, 그렇지 않으면 lca(v_i, v_{cur-1})과 똑같기 때문입니다.

     

    그러면 이제 세그먼트 트리만 갱신을 해주면 쿼리를 처리할 수 있습니다.

    세그먼트 트리의 갱신은 간단합니다. 세그먼트 트리에 저장되어 있는 lca 값들 중, cur_lca보다 큰 값들은 모두 cur_lca의 값으로 바뀌기 때문에, cur_lca에 저장된 점 개수를 구간 [cur_lca, INF]에 저장되어 있는 점의 개수에 1을 더한 만큼으로 갱신을 해주고, 구간 [cur_lca+1, INF]에 저장된 점을 모두 제거해주면 됩니다. (이는 lazy propagation을 이용하면 간단히 처리할 수 있습니다.)

     

    쿼리를 처리하는 부분의 코드는 다음과 같습니다.

    void solve(){
        ver.clear(); //ver: 쿼리로 들어온 점의 idx값과 정점 번호를 저장한 벡터
        int k;
        scanf("%d", &k);
        for (int i=0;i<k;i++){
            int tmp;
            scanf("%d", &tmp);
            ver.push_back(make_pair(idx[tmp], tmp));
        }
        sort(ver.begin(), ver.end());
        ll ans=0;
        //ans와 세그먼트 트리를 갱신하는 부분
        //세그먼트 트리에서 lcatmp값이 0일 경우 0미만의 범위 값을 불러오게 돼서 오류가 발생하기 때문에 1씩 더해서 저장
        for (int i=1;i<k;i++){
            int lcatmp=LCA(ver[i-1].second, ver[i].second);
            ans += lcatmp;
            pair<ll, int> tmp1=query(1, 0, MAXN-1, 0, lcatmp), tmp2=query(1, 0, MAXN-1, lcatmp+1, MAXN-1);
            ans += tmp1.first;
            ans += (ll)lcatmp*tmp2.second;
            update(1, 0, MAXN-1, lcatmp+1, lcatmp+1, tmp2.second+1);
            update(1, 0, MAXN-1, lcatmp+2, MAXN-1, 0);
        }
        update(1, 0, MAXN-1, 0, MAXN-1, 0);
        printf("%lld\n", ans);
    }

     

    따라서, 점 K개가 들어왔을 때 쿼리를 O(Klog(max_depth))만에 처리하는 것이 가능하고, 최대 깊이는 O(N)이므로 쿼리를 O(KlogN)에 처리하는 것이 가능합니다. 따라서, O((N+X)logN)에 문제를 푸는 것이 가능합니다. (O(NlogN)은 O(1) lca 알고리즘의 전처리과정에 의해 생긴 것입니다.)

     

    전체 코드는 다음과 같습니다.

    #include <bits/stdc++.h>
    
    typedef long long ll;
    using namespace std;
    const int MAXN=500100, LOGN=21;
    int timer=0, n;
    int euler[MAXN*2], lev[MAXN], pw2[LOGN], lg2[MAXN*2], idx[MAXN];
    pair<int, int> st[LOGN][MAXN*2];
    vector<int> adj[MAXN];
    vector<pair<int, int>> ver;
    pair<ll, int> tree[MAXN*4+10];
    bool lazy[MAXN*4+10];
    
    void dfs(int s, int pa, int l){ //오일러 투어
        lev[s]=l;
        euler[++timer]=s;
        if (!idx[s]) idx[s]=timer;
        for (int v:adj[s]) if(v!=pa){
            dfs(v, s, l+1);
            euler[++timer]=s;
        }
    }
    
    void st_build(){ //O(1) lca를 위한 sparse table build
        pw2[0]=1;
        for (int i=1;i<LOGN;i++) pw2[i]=2*pw2[i-1];
        memset(lg2, -1, sizeof lg2);
        for (int i=0;i<LOGN;i++) if(pw2[i]<MAXN*2) lg2[pw2[i]]=i;
        for (int i=2;i<MAXN*2;i++) if(lg2[i]==-1) lg2[i]=lg2[i-1];
    
        for (int j=1;j<=2*n-1;j++) st[0][j]=make_pair(lev[euler[j]], euler[j]);
        for (int i=1;i<LOGN;i++){
            for (int j=1;j<=2*n-1;j++){
                if (j+pw2[i-1]>=2*n) break;
                st[i][j]=min(st[i-1][j], st[i-1][j+pw2[i-1]]);
            }
        }
    }
    
    int LCA(int u, int v){ //O(1) lca
        int l=idx[u], r=idx[v];
        if (l>r) swap(l, r);
        int k=lg2[r-l+1];
        return min(st[k][l], st[k][r-pw2[k]+1]).first; //return level
    }
    
    //세그먼트 트리 구현
    void propagate(int i, int l, int r){
        if (lazy[i]){
            tree[i]=make_pair(0, 0);
        }
        if (lazy[i] && l!=r) lazy[i*2]=1, lazy[i*2+1]=1;
        lazy[i]=0;
    }
    
    void update(int i, int l, int r, int s, int e, int val){
        propagate(i, l, r);
        if(r<s || e<l) return;
        if (s==e && s==l && s==r){
            tree[i].second = val;
            tree[i].first = (ll)(s-1)*val;
            return;
        }
        if(s<=l && r<=e){
            lazy[i]=1; propagate(i, l, r);
            return;
        }
        int m = (l+r) / 2;
        update(i*2, l, m, s, e, val), update(i*2+1, m+1, r, s, e, val);
        tree[i].first = tree[i*2].first + tree[i*2+1].first;
        tree[i].second = tree[i*2].second + tree[i*2+1].second;
    }
    
    pair<ll, int> query(int i, int l, int r, int s, int e){
        propagate(i, l, r);
        if(r<s || e<l) return make_pair(0, 0);
        if(s<=l && r<=e){
            return tree[i];
        }
        int m = (l+r) / 2;
        pair<ll, int> ret1=query(i*2, l, m, s, e), ret2=query(i*2+1, m+1, r, s, e);
        return make_pair(ret1.first+ret2.first, ret1.second+ret2.second);
    }
    
    void solve(){ //쿼리 처리
        ver.clear();
        int k;
        scanf("%d", &k);
        for (int i=0;i<k;i++){
            int tmp;
            scanf("%d", &tmp);
            ver.push_back(make_pair(idx[tmp], tmp));
        }
        sort(ver.begin(), ver.end());
        ll ans=0;
        for (int i=1;i<k;i++){
            int lcatmp=LCA(ver[i-1].second, ver[i].second);
            ans += lcatmp;
            pair<ll, int> tmp1=query(1, 0, MAXN-1, 0, lcatmp), tmp2=query(1, 0, MAXN-1, lcatmp+1, MAXN-1);
            ans += tmp1.first;
            ans += (ll)lcatmp*tmp2.second;
            update(1, 0, MAXN-1, lcatmp+1, lcatmp+1, tmp2.second+1);
            update(1, 0, MAXN-1, lcatmp+2, MAXN-1, 0);
        }
        update(1, 0, MAXN-1, 0, MAXN-1, 0);
        printf("%lld\n", ans);
    }
    
    int main(){
        int q;
        scanf("%d %d", &n, &q);
        for (int i=2;i<=n;i++){
            int b;
            scanf("%d", &b);
            adj[b].push_back(i);
            adj[i].push_back(b);
        }
        dfs(1, -1, 0);
        st_build();
        while(q--) solve();
        return 0;
    }

     

    댓글

Designed by Tistory.