본문 바로가기

Problem Solving/백준

백준 4256 트리

https://www.acmicpc.net/problem/4256

 

4256번: 트리

첫째 줄에 테스트 케이스의 개수 T가 주어진다. 각 테스트 케이스의 첫째 줄에는 노드의 개수 n이 주어진다. (1 ≤ n ≤ 1,000) BT의 모든 노드에는 1부터 n까지 서로 다른 번호가 매겨져 있다. 다음

www.acmicpc.net

 

글을 2~3일에 하나씩 올리고 싶은데, 그것도 보통 힘든게 아닌 것 같다. 사실 다른 내용도 많이 담아야 하는데, 그동안 미뤄둔 백준 문제가 너무나 많기에 계속 이렇게 조금씩 올리고 있다.

 

트리를 다루면 전위 순회, 중위 순회, 후위 순회에 대한 것을 금방 맞닥뜨리게 된다. 너무나도 전형적인 문제이기 때문이다. 이 문제가 바로 그것이다. 재귀함수에 대한 이해를 하기도 좋고, 트리구조에 대해서 이해하기에도 좋다.

반드시 고심해서 직접 풀어보고 이해하고 넘어가야하는 문제라고 생각한다.

 

전위순회는 부모노드를 먼저 출력하고 왼쪽 자식노드부터 하나씩 확인하는 방식이다.

후위순회는 반대로 부모노드를 가장 마지막에 출력하며

중위순회는 부모노드를 중간에 출력하게 된다.

 

즉, 트리 노드를 선형화하여 출력할 때 전위, 중위, 후위 순회는 비슷한 형태로 구현되며, 어디서 노드번호를 출력하느냐만 생각해보면 된다. 아니 그런 형태로 만들어야 응용하기가 좋다. 

 

반대로 그렇게 출력된 값들을 보면서 역으로 트리를 만들려면 어떻게 해야할까. 전위순회와 중위순회가 주어진다. 전위순회의 가장 앞자리는 항상 최상단 부모노드이다. 중위순회는 가운데에서 부모노드를 출력하기에 아래 그림과 같이,

중위순회 출력결과에서 3을 찾고, 그 왼쪽과 오른쪽으로 분리해준다. 그렇다면 3의 왼쪽에는 4개의 원소가 있으므로, 전위순회에서 부모노드였던 3 다음 4개가 3의 왼쪽 서브트리가 된다는 것이다. 그럼 이제 [ 6, 5, 4, 8 ]을 들고 다시 또 재귀함수 호출하면 된다. 당연히 6이 가장 앞에 나오기에 6이 부모가 된다. 이제 중위 순회에서 6의 위치를 보고, 6의 왼쪽에는 5가 오른쪽에는 8, 4가 있음을 보면 된다.

 

즉, 다음과 같은 순서로 풀어내면 된다.

  • 전위순회 결과의 가장 앞자리가 부모노드이다. 
  • 중위순회 결과에서 해당 부모노드의 위치를 찾고, 좌측과 우측의 길이를 리턴해준다.
  • 위에서 리턴된 좌측 길이의 가장 앞에 있는 값을 부모노드 왼쪽 자식으로 연결
  • 우측 길이의 가장 앞에 있는 값을 부모노드 오른쪽 자식으로 연결
  • 전위순회 왼쪽파트를 갖고 재귀 함수 다시 호출
  • 전위순회 오른쪽파트로 재귀 함수 호출
  • 기저조건: 원소의 갯수가 1개 또는 좌, 우측 길이가 0일 때

<코드>

C++

binaryTree함수를 통해 트리를 복원한다. 

필자는 중위순회 결과안에서 바로바로 부모노드의 위치를 찾기 위해, idxInorder라는 배열을 두어 각 노드의 중위순회 결과 내 인덱스 값을 저장해두었다. 그리고 binaryTree함수에 들어가는 인자 중 마지막 fs는 좌측 서브트리인지, 우측 서브트리인지를 나타낸다.

root값은 현재 보고 있는 전위순회 결과부분의 가장 앞에 위치한 값이며, 

fs가 0이면 r의 왼쪽 자식이되고

fs가 1이면 r의 오른쪽 자식이 된다.

#include <iostream>
#define endl '\n'
using namespace std;
const int sz=1e3+10;
int t,n,parent[sz],inorder[sz],preorder[sz],idxInorder[sz];
pair<int,int>children[sz];

void binaryTree(int preS,int preE,int inS,int inE,int r,int fs){
    int root=preorder[preS];
    parent[root]=r;
    if(r!=-1){
        if(!fs)children[r].first=root;
        else children[r].second=root;
    }
    int rootIdx=idxInorder[root];
    int cntLeft=rootIdx-inS;
    int cntRight=inE-rootIdx;

    if(cntLeft)binaryTree(preS+1,preS+cntLeft,inS,rootIdx-1,root,0);
    if(cntRight)binaryTree(preS+cntLeft+1,preE,rootIdx+1,inE,root,1);
}
void postOrder(int root){
    int left=children[root].first;
    int right=children[root].second;
    if(left)postOrder(left);
    if(right)postOrder(right);
    cout<<root<<' ';
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(NULL);cout.tie(NULL);

    cin>>t;
    while(t--){
        cin>>n;
        fill(parent,parent+n+1,0);
        fill(inorder,inorder+n+1,0);
        fill(preorder,preorder+n+1,0);
        fill(idxInorder,idxInorder+n+1,0);
        fill(children,children+n+1,make_pair(0,0));
        for(int i=0;i<n;++i)cin>>preorder[i];
        for(int i=0;i<n;++i){
            cin>>inorder[i];
            idxInorder[inorder[i]]=i;
        }
        
        binaryTree(0,n-1,0,n-1,-1,0);
        postOrder(preorder[0]);
        cout<<endl;
    }
    return 0;
}

 

파이썬

사실상 C++과 동일하다.

import sys
si = sys.stdin.readline


def main():
    t = int(si())
    while t:
        t -= 1
        n = int(si())
        parent, children = [0 for _ in range(
            n+1)], [[0, 0] for _ in range(n+1)]
        idx_inorder = [0 for _ in range(n+1)]
        preorder = [int(e) for e in si().split()]
        inorder = [int(e) for e in si().split()]
        for i, e in enumerate(inorder):
            idx_inorder[e] = i

        def binary_tree(pre_s, pre_e, in_s, in_e, ancestor, fs):
            nonlocal parent, children, preorder, idx_inorder, inorder
            root = preorder[pre_s]
            parent[root] = ancestor
            if ancestor != -1:
                if not fs:
                    children[ancestor][0] = root
                else:
                    children[ancestor][1] = root

            idx_root = idx_inorder[root]
            cnt_left, cnt_right = idx_root-in_s, in_e-idx_root
            if cnt_left:
                binary_tree(pre_s+1, pre_s+cnt_left, in_s, idx_root-1, root, 0)
            if cnt_right:
                binary_tree(pre_s+cnt_left+1, pre_e, idx_root+1, in_e, root, 1)

        def print_postorder(root):
            nonlocal children

            left, right = children[root][0], children[root][1]
            if left:
                print_postorder(left)
            if right:
                print_postorder(right)
            print(root, end=' ')

        binary_tree(0, n-1, 0, n-1, -1, 0)
        print_postorder(preorder[0])
        print()


if __name__ == '__main__':
    main()