본문 바로가기

Problem Solving/백준

백준 9095 - 1, 2, 3 더하기

www.acmicpc.net/problem/9095

 

9095번: 1, 2, 3 더하기

각 테스트 케이스마다, n을 1, 2, 3의 합으로 나타내는 방법의 수를 출력한다.

www.acmicpc.net

이전 포스팅과 마찬가지로 이번에도 다이나믹 프로그래밍 문제이다. 아주 기본적인 형태이다.

 

DP(Dynamic Programming)문제를 풀 때, 가장 중요한 것 중 하나가 바로 점화식을 구하는 것이다. 점화식이란 말 그대로 어떠한 연속된 연산의 불을 붙여주는 기초 공식이라고 볼 수 있다. 즉, 이전 포스팅(백준 회의실 배정 문제)에서 살짝 언급한 바와 같이, 연속된 연산이 하나의 공식(점화식)하에서 계속 이루어지기 위해서는 모든 정점(상황)에서 항상 같은 규칙이 적용될 수 있어야 한다. 

 

이 문제에서는 어떠한 정수가 주어졌을 때, 그 정수에서 1, 2, 3을 뺀 정수(정점)가 갖는 경우의 수를 그대로 가져와서 합산하는 식으로 점화식을 짜면 된다. 여기서 몇 가지 응용을 해 볼 수 있는데, 아래와 같다. 그리고 아래의 참조 그림들을 통해 점화식을 확인하자.

 

  1.    1, 2, 3이 아닌 다른 수가 주어진 경우
  2.    1, 2, 3 + @ (더 많은 수가 주어진 경우)
  3.    1, 2, 3의 순서가 달라도 등장 횟수가 같으면 하나의 경우로 생각할 때

참조 그림1
참조 그림2

위의 참조 그림1에서 처럼

(1) 1, 2, 3이 아닌 다른 수로 더할 경우, 1, 2, 3을  다른 숫자로 넣어서 점화식을 표현하면 된다. 

(2)의 경우, 더 많은 숫자가 주어졌으므로 그만큼 항의 개수를 늘려주면 된다.

(3)의 경우에는 살짝 더 까다로워진다. 

예를 들어 7을 구하기 위해, 1+2+3+1을 하는 경우와, 1+1+2+3을 하는 경우가 같다고 할 때가 (3)의 경우가 된다.

이럴 때는 그림과 같이 2차원 배열을 사용해야 점화식을 구해야 한다. dp[i][j]에 '현재 i라는 값에 도달하기 위해 바로 직전에 j라는 숫자(또는 j-th숫자)를 더하였다'라는 의미부여를 통한 방법 등으로 구할 수 있다.

 

<주요 내용>

점화식, 다이나믹 프로그래밍(DP), 동적 계획법

 

<코드>

#include <iostream>
#include <cstring> //memset, use 0 or -1
#define endl '\n'
using namespace std;
const int sz=11+1;
int dp[sz],t,n;

int main()
{
    ios::sync_with_stdio(false);cin.tie(NULL);

    cin>>t;
    while(t--){
        cin>>n;
        memset(dp,0,sizeof(dp));
        dp[0]=1;
        for(int i=0;i<=n;++i){
            for(int j=1;j<=3;++j){
                if(i+j<=n){
                    dp[i+j]+=dp[i];
                }
            }
        }
        cout<<dp[n]<<endl;
    }
    return 0;
}

 

파이썬 코드에서는 함수 내부에서 다른 함수를 정의하였다. 그리고 바깥 함수에서 만든 변수 caseNum에 접근하기 위해 내부 함수에서 nonlocal을 사용하였다. 

def main():
    caseNum = {1: 1, 2: 2, 3: 4}
    q = []
    n = int(input())
    for _ in range(n):
        q.append(int(input()))
    # q = [4, 7, 10]

    def caseNumFinder(n):
        nonlocal caseNum
        if n in caseNum:
            return caseNum[n]
        else:
            caseNum[n] = caseNumFinder(n-3) + \
                caseNumFinder(n-2)+caseNumFinder(n-1)
            return caseNum[n]

    for e in q:
        count = caseNumFinder(e)
        print(count)


main()