ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 백준 13515번 풀이 (트리와 쿼리 6)
    문제풀이 2021. 3. 17. 17:08

    문제 링크: www.acmicpc.net/problem/13515

     

    13515번: 트리와 쿼리 6

    N개의 정점으로 이루어진 트리(무방향 사이클이 없는 연결 그래프)가 있다. 정점은 1번부터 N번까지 번호가 매겨져 있고, 간선은 1번부터 N-1번까지 번호가 매겨져 있다. 가장 처음에 모든 정점의

    www.acmicpc.net

    다음 2가지 쿼리를 처리하는 문제입니다.

     

    1 i. i번 정점의 색을 바꾼다.

    2 u. u와 연결된 정점의 개수를 출력한다. (단, 두 정점이 연결되었다는 것은 두 정점을 연결하는 경로상의 모든 정점의 색이 같다는 것을 의미한다.)

     

    이 문제의 정식 풀이는 Tree DP를 heavy-light decomposition으로 최적화하는 것이고, sqrt decomposition을 통해 오프라인 쿼리로 처리하는 풀이도 존재합니다.

     

    이 글에서는 centroid decomposition을 이용한 풀이를 설명하도록 하겠습니다.

     

    풀이)

    일단 centroid tree를 만든 후, 2번 쿼리를 어떻게 처리할지에 대해 생각해봅시다.

    centroid tree에서 정점 u부터 부모를 따라 올라가고, 현재 점 v에 있다고 해봅시다. 2번 쿼리를 처리하는 방법은 u와 v가 연결되어 있는지 확인한 후, centroid가 v인 서브트리 상에서 v와 연결된 점의 개수를 세주면 됩니다.

     

    1-1. u와 v의 연결성 확인

     

    색이 업데이트 될 때, 경로 상에 같은 색의 점만 존재하는지 판별하면 됩니다. 이는 간단한 heavy-light decomposition을 통해 O(log^2N)에 해결 가능합니다.

     

    1-2. centroid가 v인 서브트리 상에서 v와 연결된 점의 개수 세기

     

    이것만 잘 처리한다면 2번 쿼리에 대한 답을 할 수 있습니다.

    두 정점 x와 y가 연결되어 있다는 것과 동치인 명제는 "경로 x-y 상에 검은색 정점의 개수가 0개이거나, 흰색 정점의 개수가 0개이다" 라는 것을 쉽게 확인할 수 있습니다.

    따라서, 센트로이드가 v인 서브트리에 속하는 점 x에 대해, 경로 v-x 상에 있는 검은색 정점의 개수와 흰색 정점의 개수를 저장한 후, 상황에 따라 검은색 또는 흰색 정점의 개수가 0개인 경로의 개수를 세주면 됩니다. 

     

    이 작업을 하기 위해 세그먼트 트리에 다음과 같은 정보를 저장합시다. (서브트리 쿼리를 처리해야 하기 때문에 점들을 dfs order로 번호를 붙여줍시다.)

    1. v에서 [l, r]에 속한 점으로 가는 경로 중 경로 상에 있는 검은색 정점의 최소 개수와 그 최솟값을 갖는 경로의 개수

    2. 1에서 검은색 대신 하얀색에 대해 계산한 값

     

    이제, 점 u에 대해 v를 거치는 경로에 대해 값을 계산하려면 v=u인 경우 세그먼트 트리의 전체 구간에 대해 경로의 개수를 세주면 됩니다. v!=u인 경우 세그먼트 트리의 전체 구간에 대해 경로의 개수를 세준 후 u가 속한 서브트리를 나타내는 구간에서의 경로의 개수를 빼주면 됩니다. 빼주는 부분에 속한 점들까지의 경로는 v를 거치지 않고 도달하기 때문입니다.

     

    정확히 말하자면 트리 상에서 u의 부모를 따라 v까지 올라갈 때, v 직전에 만난 정점이 루트인 서브트리에 대한 쿼리로 얻은 경로의 개수를 빼주면 됩니다.

    시간복잡도는 O(logN)입니다.

     

    따라서, 1, 2번 작업을 centroid tree를 따라 올라가면서 해주면 되고, 2번 쿼리의 시간복잡도는 O(log^3N)이 됨을 확인할 수 있습니다. (시간초과가 날 수 있지 않냐는 질문을 할 수 있는데 이에 대한 설명이 뒤에 있습니다.)

     

    2. 1번 쿼리 처리

     

    1, 2번에서 사용한 세그먼트 트리들을 갱신해주기만 하면 문제가 해결됩니다.

    1번은 간단한 hld이므로, 업데이트를 O(logN)에 할 수 있습니다.

     

    2번의 경우 centroid tree를 따라 올라가면서 업데이트를 해줘야 합니다. 점 u를 업데이트 해야 하고, 현재 센트로이드가 v인 서브트리에 있다고 가정합시다. u의 색이 변하게 되면, 센트로이드가 v인 서브트리 상에서 u가 루트인 서브트리의 원소들의 값만 1씩 변하게 되고, 나머지 값들은 영향을 받지 않습니다. 따라서, 점 u가 속한 서브트리의 값들만 세그먼트 트리 상에서 업데이트를 해주면 되고, 점들을 dfs order로 번호를 붙였기 때문에 lazy propagation을 통해 O(logN)에 업데이트가 가능하고, 각 센트로이드에 대해 업데이트를 해주면 O(log^2N)에 업데이트가 가능합니다.

     

    따라서, 1번 쿼리의 총 시간복잡도는 O(log^2N)이 됩니다.

     

    3. 시간초과가 날 수도 있지 않나요?

     

    위와 같은 방식으로 쿼리를 처리하면 총 시간복잡도는 O(NlogN+Qlog^3N)임을 알 수 있습니다. N과 Q는 최대 100000이기 때문에 상수가 조금만 커도 바로 시간초과가 날 가능성이 높습니다.

    이를 해결하는 방법은 hld에 사용하는 세그먼트 트리를 bottom-up 방식으로 구현하는 것입니다.

     

    대부분의 사람들이 쓰는 세그먼트 트리는 top-down 방식으로, 전체 구간에서 절반으로 쪼개가며 쿼리를 처리합니다. 하지만 이런 방식으로 쿼리를 처리하게 되면 구간의 길이가 짧을수록 시간이 더 오래 걸리게 되고, heavy-light decomposition을 통해 light edge가 logN개 있을 때 쿼리를 처리하게 된다면 반드시 log^2N번의 연산이 필요하게 되면서 시간이 오래 걸리게 됩니다.

     

    하지만 bottom-up 방식의 경우, 양쪽 끝에서 구간을 잘라나가는 방식으로 쿼리를 처리하고, 구간의 길이가 짧을수록 시간이 덜 걸리게 됩니다. 따라서, 짧은 구간 쿼리를 여러번 처리해야 하는 heavy-light decomposition의 시간을 효과적으로 줄여줄 수 있다는 것을 알 수 있습니다.

     

    bottom-up 방식의 세그먼트 트리에 대한 구현은 다음 링크에 잘 설명되어 있습니다.

    codeforces.com/blog/entry/18051

     

    Efficient and easy segment trees - Codeforces

     

    codeforces.com

    이 방식으로 구현하게 되면 로그가 세제곱임에도 불구하고 상수가 작아 1초 근처로 통과하는 것을 확인할 수 있습니다.

    이렇게 최적화를 해도 정풀은 시복이 더 작아서 순위표에서 거의 꼴등입니다

    tmi) 세그먼트 트리에 저장하는 값을 약간만 바꾸면 트리와 쿼리 7도 풀 수 있습니다.

     

    수정) 구현에 대해 질문하신 분이 있어서 코드를 올렸습니다.

     

    #include <bits/stdc++.h>
    
    using namespace std;
    typedef long long ll;
    bool color[100100];
    struct node{
        int mnb, mnw, c1, c2;
        node(){}
        node(int _mnb, int _mnw, int _c1, int _c2): mnb(_mnb), mnw(_mnw), c1(_c1), c2(_c2) {};
    };
    struct seg{
        int tree1[200200], tree2[200200];
        int sz;
        void update(int idx){
            idx += sz-1;
            tree1[idx] ^= 1, tree2[idx] ^= 1;
            for (;idx>1;idx>>=1){
                tree1[idx>>1] = tree1[idx] | tree1[idx^1];
                tree2[idx>>1] = tree2[idx] & tree2[idx^1];
            }
        }
        int query_or(int l, int r){
            int ret=0;
            l--;
            for (l+=sz, r+=sz;l<r;l>>=1, r>>=1){
                if (l&1) ret |= tree1[l++];
                if (r&1) ret |= tree1[--r];
            }
            return ret;
        }
        int query_and(int l, int r){
            int ret=1;
            l--;
            for (l+=sz, r+=sz;l<r;l>>=1, r>>=1){
                if (l&1) ret &= tree2[l++];
                if (r&1) ret &= tree2[--r];
            }
            return ret;
        }
    }hld_tree;
    struct seg2{
        vector<node> arr, tree;
        vector<pair<int, int>> lazy;
        int sz;
        node combine(node a, node b){
            node ret = a;
            if (ret.mnb>b.mnb){
                ret.mnb = b.mnb, ret.c1 = b.c1;
            }
            else if (ret.mnb == b.mnb) ret.c1 += b.c1;
            if (ret.mnw>b.mnw){
                ret.mnw = b.mnw, ret.c2 = b.c2;
            }
            else if (ret.mnw == b.mnw) ret.c2 += b.c2;
            return ret;
        }
        void init(int i = 1, int l = 0, int r = -1){
            if (r==-1) r = sz-1;
            if (l==r){
                tree[i] = arr[l]; return;
            }
            int m = (l+r)>>1;
            init(i<<1, l, m); init(i<<1|1, m+1, r);
            tree[i] = combine(tree[i<<1], tree[i<<1|1]);
        }
        void propagate(int i, int l, int r){
            tree[i].mnb += lazy[i].first, tree[i].mnw += lazy[i].second;
            if (l!=r){
                lazy[i<<1].first += lazy[i].first, lazy[i<<1].second += lazy[i].second;
                lazy[i<<1|1].first += lazy[i].first, lazy[i<<1|1].second += lazy[i].second;
            }
            lazy[i] = make_pair(0, 0);
        }
        void update(int i, int l, int r, int s, int e, int val1, int val2){
            propagate(i, l, r);
            if (r<s || e<l) return;
            if (s<=l && r<=e){
                lazy[i].first += val1, lazy[i].second += val2; propagate(i, l, r);
                return;
            }
            int m = (l+r)>>1;
            update(i<<1, l, m, s, e, val1, val2); update(i<<1|1, m+1, r, s, e, val1, val2);
            tree[i] = combine(tree[i<<1], tree[i<<1|1]);
        }
        node query(int i, int l, int r, int s, int e){
            propagate(i, l, r);
            if (r<s || e<l) return node(1e9, 1e9, 0, 0);
            if (s<=l && r<=e) return tree[i];
            int m = (l+r)>>1;
            return combine(query(i<<1, l, m, s, e), query(i<<1|1, m+1, r, s, e));
        }
    } st[100100];
    vector<int> adj[100100], g[100100]; ///graph, hld_graph
    pair<int, int> subtree[100100][20]; ///subtree
    bool visited[100100]; ///checking centroid
    int sz[100100], sz2[100100], par[100100], par2[100100], top[100100], in[100100], dep2[100100], db[100100][20], cent_dep[100100], n, seg_timer;
    ///centroid sz, hld sz, centroid tree, hld tree, hld chain, hld dfs order, hld depth, centroid root, centroid depth
    
    int getsize(int s, int pa = -1){ ///getsize (sz)
        sz[s] = 1;
        for (int v:adj[s]) if (!visited[v] && v!=pa) sz[s] += getsize(v, s);
        return sz[s];
    }
    
    int getcent(int s, int pa, int cap){ ///get centroid
        for (int v:adj[s]) if (!visited[v] && v!=pa && (sz[v]<<1)>cap) return getcent(v, s, cap);
        return s;
    }
    
    void getseg(int s, int pa, int rt, int i, int dep, int dfs_dep){ ///subtree, db, st[i].arr initialize **
        db[s][dep] = rt;
        st[i].arr[seg_timer] = node(dfs_dep, 0, 1, 1);
        subtree[s][dep].first = seg_timer+1;
        seg_timer++;
        for (int v:adj[s]) if (!visited[v] && v!=pa){
            if (rt==-1) getseg(v, s, v, i, dep, dfs_dep+1);
            else getseg(v, s, rt, i, dep, dfs_dep+1);
        }
        subtree[s][dep].second = seg_timer;
    }
    
    void getcentree(int s = 1, int pa = -1, int cap = n, int dep = 0){ ///make centroid tree, st[i] initialize, cent_dep initialize **
        getsize(s, pa);
        int cent = getcent(s, pa, cap);
        cent_dep[cent]=dep;
        st[cent].tree.resize(cap<<2); st[cent].lazy.resize(cap<<2); st[cent].arr.resize(cap);
        st[cent].sz = cap;
        seg_timer=0;
        getseg(cent, -1, -1, cent, dep, 1);
        st[cent].init();
        visited[cent]=1;
        if (pa!=-1) par[cent] = pa;
        for (int v:adj[cent]) if (!visited[v]){
            getcentree(v, cent, sz[v], dep+1);
        }
    }
    
    bool visited2[100100];
    void dfs(int s = 1){
        visited2[s]=1;
        for (int v:adj[s]) if (!visited2[v]){
            g[s].push_back(v);
            dfs(v);
        }
    }
    
    void dfs1(int s = 1){
        sz2[s]=1;
        for (int &v:g[s]){
            par2[v]=s; dep2[v]=dep2[s]+1;
            dfs1(v); sz2[s] += sz2[v];
            if (sz2[v]>sz2[g[s][0]]) swap(v, g[s][0]);
        }
    }
    
    int pv;
    void dfs2(int s = 1){ ///hld initialize
        in[s] = ++pv;
        for (int v:g[s]){
            top[v] = v==g[s][0] ? top[s] : v; dfs2(v);
        }
    }
    
    int hld_query(int x, int y){ ///**
        int chk1=0, chk2=1;
        while(top[x] != top[y]){
            if (dep2[top[x]]<dep2[top[y]]) swap(x, y);
            int st = top[x];
            chk1 |= hld_tree.query_or(in[st], in[x]);
            chk2 &= hld_tree.query_and(in[st], in[x]);
            x = par2[st];
        }
        if (dep2[x]>dep2[y]) swap(x, y);
        chk1 |= hld_tree.query_or(in[x], in[y]);
        chk2 &= hld_tree.query_and(in[x], in[y]);
        if (chk1 ^ chk2) return -1;
        return chk1;
    }
    
    void update(int v){ ///**
        hld_tree.update(in[v]);
        for (int i=v;i;i=par[i]){
            if (color[v]) st[i].update(1, 1, st[i].sz, subtree[v][cent_dep[i]].first, subtree[v][cent_dep[i]].second, 1, -1);
            else st[i].update(1, 1, st[i].sz, subtree[v][cent_dep[i]].first, subtree[v][cent_dep[i]].second, -1, 1);
        }
        color[v] = !color[v];
    }
    
    int query(int v){ ///**
        int ret=0;
        for (int i = v;i;i = par[i]){
            int tmp = hld_query(v, i);
            if (tmp==-1) continue;
            if (!tmp){
                ret += st[i].tree[1].c2;
                if (db[v][cent_dep[i]]>0){
                     node tmp_node = st[i].query(1, 1, st[i].sz, subtree[db[v][cent_dep[i]]][cent_dep[i]].first, subtree[db[v][cent_dep[i]]][cent_dep[i]].second);
                     if (!tmp_node.mnw) ret -= tmp_node.c2;
                }
            }
            else{
                ret += st[i].tree[1].c1;
                if (db[v][cent_dep[i]]>0){
                     node tmp_node = st[i].query(1, 1, st[i].sz, subtree[db[v][cent_dep[i]]][cent_dep[i]].first, subtree[db[v][cent_dep[i]]][cent_dep[i]].second);
                     if (!tmp_node.mnb) ret -= tmp_node.c1;
                }
            }
        }
        return ret;
    }
    
    int main(){
        scanf("%d", &n);
        for (int i=0;i<n-1;i++){
            int a, b;
            scanf("%d %d", &a, &b);
            adj[a].push_back(b);
            adj[b].push_back(a);
        }
        hld_tree.sz=n;
        getcentree(); dfs(); dfs1(); dfs2();
        int q;
        scanf("%d", &q);
        while(q--){
            int a, b;
            scanf("%d %d", &a, &b);
            if (a&1) update(b);
            else printf("%d\n", query(b));
        }
        return 0;
    }
    

    댓글

Designed by Tistory.