이 문제는 알고리즘 & 자료구조를 공부한다면, 반드시 한 번쯤은 거쳐가고, 거쳐야만 하는 기본적(?)이면서도 중요한 문제라고 생각한다. 그리고 다양한 풀이가 있을 수 있어서 좋다.
위의 그림은 문제에서 주어진 예제를 다루고 있다. 항상 그렇듯이 이 문제에서도 정렬의 힘을 다시 느낄 수 있다. 입력을 받은 후 정렬하고, 총 3개의 수를 뽑아야 하므로 루프 2개로 2개를 뽑아준다 (여기까지 O(n^2)). 위의 예에서는 처음에 -97과 -6이 선택되고 그 합은 -103이 된다. (그림에서는 오타가 있었다 ㅠㅠ 양해 부탁드림) 이제 103이 있어야 0을 만들 수 있는 것이다. 그렇다면 103라는 값을 나머지 범위 내에서 이분탐색을 통해서 찾는 것이다. 그리고 이분탐색을 행할 수 있는 이유는 정렬이 되어있기 때문이다. 이분탐색은 로그 시간이 필요하므로 총 시간 복잡도는 O(n^2 log n)이 된다. 그리고 주의할 점은, 103라는 정확한 값을 찾는게 아니라, 전부 합해서 0에 가까운 값을 찾는 것이므로 한 칸 앞의 값도 봐야 한다는 것이다. 정렬이 되어있기에 idx와 그 한 칸 앞의 값, 총 2개만 확인하면 0에 가장 가까운 합을 구할 수 있다.
int idx=lower_bound(arr,arr+n,target)-arr;
// check idx-1 !
위의 코드와 같이 lower_bound등을 통해 이분탐색으로 목표값의 인덱스를 찾았다면 그 한 칸 앞이 적합 범위를 벗어나지 않는 한 확인해봐야 한다는 것이다.
필자는 C++은 이렇게 풀어냈지만 파이썬으로는 시간초과가 났다. 코드를 조금 더 다듬을 수도 있겠지만, 알고리즘 자체를 더 고치기로 하였다. 그래서 생각한 것이 투 포인터이다
위의 그림과 같이 이번에는 한개의 수만 루프로 선택한다. 물론 루프는 n-2까지만 돌아야 함에 주의한다. 한 개를 정하고 남은 나머지 범위에 대해서 양 끝 값을 잡고 투 포인터를 행하면 된다. 위의 예시처럼 s=-6, e=98이라면 합은 92가 된다. 앞서 -97을 선택하였고, 97이 있어야 0이 되는데, 92가 나왔으니 값이 모자라다. 여기서 바로 정렬의 힘이 발휘된다! 이 수열은 정렬되어 있기에 뒤로 갈수록 큰 수가 나온다. 그말은 e를 줄이면 합이 작아지고, s를 키우면 합이 커진다는 뜻이다. 우리는 더 큰 값을 원하기 때문에 이번에 s++을 해줘야 하는 것이다.
이런 식으로 인덱스 0부터 n-2까지 루프를 돌고, 매 회차마다 투 포인터를 통해 나머지 범위를 전부 확인하기에 총 시간복잡도는 O(n^2)이 된다. 그리고 이렇게 해결하는 것이 더 빠른 방법이다
<주요 내용>
투 포인터, 정렬, 이분탐색
<코드>
C++
#include <iostream>
#include <algorithm>
#define endl '\n'
using namespace std;
typedef long long ll;
const int sz=5e3;
int n,arr[sz],idx,ans1,ans2,ans3;
ll temp,mn=3e9+1;
bool flag;
int main()
{
ios::sync_with_stdio(false);cin.tie(NULL);cout.tie(NULL);
cin>>n;
for(int i=0;i<n;++i)cin>>arr[i];
sort(arr,arr+n);
for(int i=0;i<n-2;++i){
for(int j=i+2;j<n;++j){
temp=0-(arr[i]+arr[j]);
idx=lower_bound(arr+i+1,arr+j-1,temp)-arr;
if(idx>i+1 && abs(arr[idx-1]-temp)<abs(arr[idx]-temp))idx--;
temp=arr[idx]-temp;
if(abs(temp)<mn){
mn=abs(temp);
ans1=arr[i],ans2=arr[idx],ans3=arr[j];
}
if(!abs(temp)){
flag=1;
break;
}
}
if(flag)break;
}
cout<<ans1<<' '<<ans2<<' '<<ans3<<endl;
return 0;
}
아래는 파이썬 코드이다.
# O(N^2)
import sys
import bisect
si = sys.stdin.readline
def main():
n = int(si())
solutions = [int(e) for e in si().split()]
solutions.sort()
if solutions[0] <= 0 and solutions[-1] <= 0:
print(solutions[-3], solutions[-2], solutions[-1])
return
if solutions[0] >= 0 and solutions[-1] >= 0:
print(solutions[0], solutions[1], solutions[2])
return
mx, flag = int(3e9+1), False
for idx in range(n-2):
target, s, e = -solutions[idx], idx+1, n-1
while s < e:
temp = solutions[s]+solutions[e]
total = temp+solutions[idx]
if abs(total) < mx:
mx = abs(total)
ans1, ans2, ans3 = solutions[idx], solutions[s], solutions[e]
if not total:
flag = True
break
if temp < target:
s += 1
else:
e -= 1
if flag:
break
print(ans1, ans2, ans3)
if __name__ == '__main__':
main()