ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 백준 18799번 풀이 (이상한 편집기)
    문제풀이 2021. 3. 26. 18:14

    문제 링크: www.acmicpc.net/problem/18799

     

    18799번: 이상한 편집기

    스택에 전체 문자열을 추가한 뒤, 한 번 출력하는 것이 최적이다.

    www.acmicpc.net

    문제를 보고 dp문제 느낌이 나길래 dp로 풀었습니다. 제한이 2000이기 때문에 O(N^3)으로는 안 풀리고, O(N^2logN)으로 풀기 위해 slope trick을 사용했습니다. 근데 이걸 slope trick이라 해도 되나?

     

    풀이)

    처음에 naive한 dp 풀이를 만든 후, 최적화하는 방식으로 풀이를 설명하도록 하겠습니다.

    입력으로 들어온 문자열을 str, str의 길이를 n이라 하겠습니다.

     

    O(N^4) Solution

    다음과 같은 dp배열을 정의합시다.

    dp[l][r]: 마지막으로 붙여넣은 문자열이 구간 [l, r]일 때, 편집기 사용의 최소 횟수

     

    이때, 다음 점화식을 만족한다는 것을 확인할 수 있습니다.

    dp[i][j] = min_(1<=k<=i-1){dp[k][i-1] + val(k, i, j)) (인덱스는 1-based 입니다.)

     

    두 부분문자열 [k, i-1]과 [i, j]를 앞부터 비교하면서 t개의 문자가 일치한다고 했을 때, [k, i-1]을 입력한 후 [i, j]를 입력하기 위해 필요한 최소 횟수는 val(k, i, j) = (i-k-t) + (j-i+1-t) + 1이 됩니다. (t개만 남기고 전부 제거, 추가, 붙여넣기)

     

    val(k, i, j)를 naive하게 계산하면 O(N)의 시간이 걸리고, dp[i][j]를 naive하게 계산하면 O(N^2)이 걸리므로, dp table을 전부 채우는데 걸리는 시간은 O(N^4)이 됩니다.

     

    문제의 답은 min_(1<=k<=n)dp[k][n]이 됨을 쉽게 알 수 있습니다.

     

    O(N^3) Solution

    val(k, i, j)를 O(1)에 계산하는 방법에 대해 알아봅시다.

    부분문자열 [k, i-1]과 [i, j]에 대해 일치하는 최대 prefix 길이 t를 구할 때 O(N)의 시간이 걸리기 때문에 val(k, i, j)를 계산하려면 O(N)의 시간이 걸리게 됩니다.

     

    다음과 같은 배열을 전처리합시다.

    cs[i][j]: 부분문자열 [i, n]과 [j, n]에 대해 일치하는 최대 prefix 길이

     

    cs[i][j] 배열을 naive하게 채우면 O(N^3)의 시간이 걸리고, val(k, i, j)를 계산할 때 t = min(cs[k][i], i-k, j-i-1)을 만족하기 때문에 val(k, i, j)를 O(1)에 계산할 수 있습니다.

     

    따라서, 답을 O(N^3)에 계산할 수 있습니다.

     

    O(N^2logN) Solution

    일단 앞에서 정의한 cs배열을 O(N^2)에 전처리합시다.

    str[i] == str[j]일 때, cs[i][j] = cs[i+1][j+1] +1을 만족하고, 그렇지 않을 때는 cs[i][j] = 0을 만족함을 알 수 있습니다.

    따라서, 0 <= d <= n-1인 d에 대해, cs[i][i+d]를 가능한 i에 대해 O(N)에 구할 수 있고, 모든 d에 대해 이를 계산하면 O(N^2)에 cs배열을 전처리할 수 있습니다.

     

    이제, dp를 최적화 해봅시다. k와 i를 고정시키고, j의 값만 변화시키는 상황을 생각해봅시다. j-i+1 = min(cs[i][k], i-k)인 상황을 생각해보면, [k, i-1]을 입력한 후, 스택에 있는 문자열의 길이를 min(cs[i][k], i-k)까지 감소시키고 글을 붙여넣으면 되기 때문에 val(k, i, j) = i-k-min(cs[i][k], i-k)+1이 됩니다. 이때의 j 값을 j0라 하고, val값을 val0라고 합시다. 또한, t = min(cs[i][k], i-k)라고 합시다.

     

    j가 j0보다 x만큼 클 경우, 스택에 있는 문자열의 길이를 t까지 감소시킨후 x개를 추가하고 붙여넣어야 하므로 val(k, i, j) = val+x가 됩니다. 반대로, j가 j0보다 x만큼 작을 경우, 스택에 있는 문자열의 길이를 t까지 감소시킨후 x개를 추가로 감소시켜야 하기 때문에 val(k, i, j) = val+x가 됩니다.

     

    이를 종합해보면, 고정된 k, i에 대해 val(k, i, j) = val0 + |j-j0| 라는 식을 만족함을 알 수 있습니다. 이를 j를 x축에 놓고 그래프를 그려보면, 아래와 같이 절댓값 그래프가 나온다는 것을 확인할 수 있습니다.

     

     

    이제 i만 고정시키고 생각해봅시다. 각 k에 대해, j의 값을 x축에 놓고 val의 값을 y축에 놓은 그래프들을 그리면 다음과 같은 형태의 그림이 나오게 됩니다.

     

     

    이때, dp[i][j]의 값은 모든 k값에 대해 그 중 최소인 것만 고른 그래프이므로, 다음과 같이 최소인 지점들만 남게 됩니다.

     

     

    그래프에서 각 부분은 기울기가 1 또는 -1인 절댓값 그래프의 일부이므로, 그림과 같이 아래로 튀어나온 꼭짓점들의 (j, val) 값을 셋(set)에 저장해놓으면 j값에 대한 이분탐색을 통해 O(logN)에 val값을 빠르게 계산할 수 있습니다.

     

     

    그렇다면 이러한 꼭짓점들을 어떻게 셋에 추가해야할까요? 다음과 같은 방법을 따라가면 O(NlogN)에 셋에 원소들을 추가할 수 있습니다.

     

    1. 각 꼭짓점들을 val값에 대한 오름차순으로 정렬

    2. 꼭짓점 (j, val)을 추가하려고 할 때, 현재 그래프에서 x좌표가 j일 때 y좌표값 tmp 계산

    3. val<tmp일 때만 셋에 꼭짓점 추가

     

    그림을 보면서 작동원리를 이해해봅시다. 일단, 처음에는 아무 그래프도 추가되지 않은 상태이므로 그래프를 1개 추가해줍니다.

     

    그림과 같이 기존에 있던 그래프에서 x좌표가 j일 때 y좌표 값이 추가하려는 그래프의 꼭짓점의 y좌표보다 작은 경우, 이 그래프는 그냥 무시하면 됩니다.

     

    추가하려는 그래프(주황색), 기존 그래프 (하늘색)

     

    반대로, 기존에 있던 그래프에서 x좌표가 j일 때 y좌표 값이 추가하려는 그래프의 꼭짓점의 y좌표보다 큰 경우, 그 그래프의 꼭짓점을 셋에 추가해줘야 합니다. 이때, 꼭짓점들을 사전에 y좌표에 대해 오름차순으로 정렬을 해놓았기 때문에, 셋에서 다른 꼭짓점을 제거해야 하는 상황은 발생하지 않습니다.

     

    추가하려는 그래프(주황색), 기존 그래프(하늘색), 변경된 그래프(연두색)

     

    따라서, 이 방식을 이용하면 O(NlogN)에 대해 dp[i][j]의 그래프를 얻을 수 있고, 각 j에 대해 dp[i][j]를 O(logN)에 계산할 수 있습니다.

     

    이를 모든 i에 대해 하게 되면 O(N^2logN)에 dp[i][j]의 값을 얻을 수 있게 되고, 문제가 풀립니다.

     

     

    소스코드 

    #include <bits/stdc++.h>
    
    using namespace std;
    typedef long long ll;
    int dp[2020][2020], cs[2020][2020];
    char a[2020];
    vector<pair<int, int>> vec;
    set<pair<int, int>> st;
    
    bool comp(pair<int, int> &a, pair<int, int> &b){
        return a.second<b.second;
    }
    
    ll calc(int x){
        auto pos = st.lower_bound(make_pair(x, -1e9));
        ll val1 = 1e9, val2 = 1e9;
        if (pos != st.begin()){
            val1 = (*(--pos)).second + (x-(*pos).first); ++pos;
        }
        if (pos != st.end()) val2 = (*pos).second + ((*pos).first-x);
        return min(val1, val2);
    }
    
    int main(){
        int n=0;
        scanf("%s", a);
        for (int i=0;a[i];i++) n++;
        for (int d=0;d<n;d++){
            for (int i=n-1-d;i>=0;i--){
                if (a[i] == a[i+d]) cs[i][i+d] = cs[i+1][i+d+1]+1;
                else cs[i][i+d] = 0;
            }
        }
        for (int j=0;j<n;j++) dp[0][j] = j+2;
        for (int i=1;i<n;i++){
            vec.clear(); st.clear();
            for (int z=0;z<i;z++) vec.push_back(make_pair(min(cs[z][i], i-z), dp[z][i-1] + max(i-z-cs[z][i], 0) +1));
            sort(vec.begin(), vec.end(), comp);
            for (auto &p:vec) if (st.empty() || calc(p.first)>p.second) st.insert(p);
            for (int j=i;j<n;j++) dp[i][j] = calc(j-i+1);
        }
        int ans = 1e9;
        for (int i=0;i<n;i++) ans = min(ans, dp[i][n-1]);
        printf("%d\n", ans);
        return 0;
    }

    댓글

Designed by Tistory.