https://www.acmicpc.net/problem/1967
트리의 기본 성질 중 하나를 알고 있으면 풀이를 도출하는 게 보다 수월하다. '트리는 어떤 정점을 루트(기준)로 삼아도 트리의 형태로 만들 수 있다.' 이 말이 무엇인지 이해하기 위해 아래 그림을 보자.
위의 그림은 문제에 나온 트리이다. 정점 1이 루트에 있다. 앞서 말한 '어떤 정점을 루트로 삼아도 트리의 형태가 된다'를 확인해보자.
위의 그림과 같이 같은 트리를 6번 정점을 루트로 삼아 다시 그려냈다. 똑같이 트리의 형태로 나타낼 수 있다.
이것이 왜 중요하냐면, 문제에서 주어지는 정보로는 어떠한 정점이 루트 정점인지 알 수 없기 때문이다. 하지만 위에서 보았듯이, 어떠한 정점을 루트로 삼아도 상관없다는 것을 알았다. 특히, 문제에서 요구하는 트리의 지름을 구하기 위해서는 아무 정점이나 루트로 삼아도 된다. 문제에서 정점 번호는 항상 1번부터 시작함을 알 수 있다. 이제 1번 정점을 루트로 삼으면 어떤 케이스가 나와도 해결할 수 있다.
필자가 생각한 방식은 다음과 같다. (참고: 노드=정점)
- 꼬리 노드를 찾는다 (자식 노드가 없음).
- 부모 노드로 이동한다.
- 부모 노드는 자신에게 오는 값들 중 가장 큰 값 2개를 저장해둔다.
- 그리고 그 중 가장 큰 것만 자신의 부모에게 올려준다.
- 이런 식으로 모든 정점에 대해 반복한다.
- 이제 모든 정점에 대해 각각이 저장하고 있는 가장 큰 값 2개를 합해서 지름을 구한다.
이제 꼬리 노드를 어떻게 찾을 것인가? 앞서 설명한 루트 찾기가 중요했는데, 아무거나 루트로 삼아도 됨을 알았다. 1번 정점을 (어떤 테스트 케이스든 1번 정점은 존재한다) 루트로 삼아서 자식 노드들로 BFS를 돌아준다. 그리고 각 정점에 대해 자신과 연결된 간선의 개수를 세어준다. 이는 ins[i]에 저장한다. 여기서 눈치를 챘다면, ins[i]=0인 정점이 바로 꼬리 노드임을 알게 될 것이다.
꼬리 노드들을 모아서 큐에 담아주고 위상정렬을 돌아준다. 그리고 각 노드에서 부모 노드로 올라가는 동안 앞서 말한 가장 큰 값 2개 저장 등의 처리를 해준다. 이렇게 가장 루트인 1번 정점까지 가면 모든 작업이 끝난다.
이제 다시 모든 정점을 한번씩 확인하며 최대값(지름)을 찾는다. 사실 위상정렬을 도는 동안 지름을 계속 갱신하는 방식으로 찾아도 상관없다.
<주요 내용>
위상정렬, BFS, 트리, 트리의 지름
이 문제는 간선마다 가중치가 있어서 이렇게 구하지만, 가중치가 동일하거나 없는 그래프에서는 조금 다른 방식으로 풀어낸다. 그리고 그 풀이는 매우 중요하다. 이후 다른 포스팅에서 소개하겠다.
<코드>
C++
#include <iostream>
#include <vector>
#include <queue>
#include <algorithm>
#define pii pair<int,int>
#define mp make_pair
#define endl '\n'
using namespace std;
const int sz=1e4+1;
bool vis[sz];
int n,a,b,c,v[sz][2],ins[sz],mx;
vector<pii>g[sz];
queue<int>q;
int main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
cin>>n;
for(int i=0;i<n-1;++i){
cin>>a>>b>>c;
g[a].push_back(mp(b,c));
g[b].push_back(mp(a,c));
}
vis[1]=1;
q.push(1);
while(!q.empty()){
int curr=q.front();q.pop();
for(auto& nxp:g[curr]){
if(!vis[nxp.first]){
ins[curr]++;
vis[nxp.first]=1;
q.push(nxp.first);
}
}
}
for(int i=1;i<=n;++i)if(!ins[i])q.push(i);
while(!q.empty()){
int curr=q.front();q.pop();
for(auto& nxp:g[curr]){
int& nx=nxp.first,nc=nxp.second;
if(ins[nx]){
ins[nx]--;
if(v[nx][0]<nc+v[curr][0]){
v[nx][1]=v[nx][0];
v[nx][0]=nc+v[curr][0];
}else if(v[nx][1]<nc+v[curr][0]){
v[nx][1]=nc+v[curr][0];
}
if(!ins[nx])q.push(nx);
}
}
}
for(int i=1;i<=n;++i)mx=max(mx,v[i][0]+v[i][1]);
cout<<mx;
return 0;
}