https://www.acmicpc.net/problem/2463
좋은 문제이다. 가장 좋은 문제 몇 손가락 안에 들 정도로.
발상의 전환이 필요하고, 유니온 파인드의 의미에 대한 이해가 필요하다. (최소 스패닝 트리 알고리즘 안에 유니온 파인드가 포함된다.)
유니온 파인드에서 union 으로 연결된 정점들은 일종의 그룹을 만들게 된다. 자연히 그들은 루트(대표)를 갖는다. 그리고 그 루트 정점의 식별자가 해당 그룹의 식별자이기도 하다.
- 한 정점이 어떤 그룹에 포함되는지
- 어떤 정점과 다른 정점이 같은 그룹에 있는지 (연결되어 있는지)
- 각 그룹에는 몇 개의 정점이 포함되는지
- 총 몇개의 그룹이 있는지
- 어떠한 그룹에도 포함되지 않은 독립된 정점이 있는지, 몇 개 있고 그들이 누구인지.
위와 같은 다양한 정보를 얻을 수 있다.
문제에서 중요한 제한 사항이 있다. 바로 모든 경로의 값이 다르다는 점이다. 이것이 중요한 힌트이자 풀이의 근거가 된다. 모든 경로의 값이 다르기에 그래프의 형상이 관련 없어진다. 누구와 누가 연결되어있는지만 알면 된다. 같은 가중치를 가진 경로가 2개 이상 있는 그래프라면 몇 개의 케이스가 있는지 찾을 때 문제가 있으리라 본다.
또한, 가장 작은 가중치의 경로부터 순서대로 제거해야 한다.
가장 기본적인 최소 스패닝 트리를 문제를 해결할 때, 경로를 오름차순 정렬하고 (내림차순 정렬하면 반대로 최대 스패닝 트리가 될 것이다.), 하나씩 연결하면서 정점들을 이어준다.
이 문제에서는 두 정점간 Cost(u,v)를 구하기 위해 반대로 가장 작은 가중치의 경로부터 끊어준다. 절대 끊기면 안 되는 마지막 경로까지 끊었을 때의 총비용을 구하는 것이다. 근데 처음에 한번 구하는 u1, v1간의 정보를 다른 Cost(u2,v2)를 구할 때 전혀 활용할 수 없다. u2, v2간의 경로들은 u1, v1간의 경로와 전혀 중복되지 않을 수도 있기 때문이다.
총 n개의 정점이 있으면 (n-1)!개의 Cost(u,v)를 구해야 하는데 n이 최대 10만 개다...
작은 경로부터 하나씩 끊여서 두 정점 간의 연결 여부를 BFS 등으로 매번 확인할 수도 없다.
따라서, 다른 접근이 필요하며 고민을 많이 했다.
왼쪽의 그림은 문제에서 주어진 그래프이다.
뭔가 규칙성을 발견하기 위해 직접 가장 원시적인 방법을 시도해보는 것도 방법이 될 수 있다. 그래서 다소 무리일 수도 있지만, 모든 Cost(u,v)를 전부 구해보았다...
값을 다 더해서 계산해버리면 규칙성을 발견하기 어렵다. 그래서 우리가 수학이나 물리 등에서 기호를 사용하거나 계산되지 않은 숫자들을 그대로 가져가는 것이다.
1~2 : 2+3+4+5+6+10
1~3 : 2+3+4+5+6
1~4 : 2+3+4+5
1~5 : 2+3+4
1~6 : 2+3+4+5+6
2~3: 2+3+4+5+6
2~4 : 2+3+4+5
2~5 : 2+3+4
2~6 : 2+3+4+5+6
3~4 : 2+3+4+5
3~5 : 2+3+4
3~6 : 2+3+4+5+6+10+15
4~5 : 2+3+4
4~6 : 2+3+4+5
5~6 : 2+3+4
위에서 보다시피, 뭔가 2, 3, 4 등의 작은 가중치를 가진 경로가 자주 등장한다. 자주 등장하는 정도가 아니라 모든 Cost(u,v)에 포함되어 있다. 즉, 중복되고 있다는 것이다.
아까 위에서 "근데 처음에 한번 구하는 u1, v1간의 정보를 다른 Cost(u2,v2)를 구할 때 전혀 활용할 수 없다."라고 언급했다. 중복을 제거하려면 일단 뭔가 가장 가중치가 큰 경로부터 처리하면 될 것 같다는 생각이 든다.
가중치가 큰 순서로 경로를 하나씩 찾아서 연결해주는 것이다! 그리고 위에서 정리한 Cost(u,v)는 모든 경로 가중치의 합에서 높은 가중치들을 빼준 상태로도 볼 수 있다. 모든 경로 가중치의 합은 2+3+4+5+6+10+15=45 이다.
예를 들어, 2~6 은 45-10-15 이고, 3~6 은 아무것도 빠지지 않았다.
3과 6의 연결을 끊기 위해서는 모든 경로를 다 끊어야 한다. 즉 45의 비용이 든다. 이제 3과 6을 연결해준다.
그리고 그 다음 1과 2의 연결을 끊기 위해서는 30의 비용이 든다. 가중치 15인 경로도 사실 1과 2의 연결을 끊기 전에 끊어도 된다. 하지만 문제에서 가중치가 작은 순서대로 끊기 시작해서 1, 2간 연결이 없을 때까지 끊는다고 하였다. 그래서 15가 상관이 없다.
이 다음이 중요하다. 왼쪽 그림이 현재까지의 상태이다.
다음은 가중치 6의 2와 6을 연결하는 경로이다.
2와 6이 연결되면, 그 둘만 연결되는 게 아니다.
1-3, 1-6, 2-3, 2-6 총 4개의 연결이 생기게 된다.
이 4가지의 경우는 모두 같은 Cost(u,v)를 갖게 된다. 2, 3, 4, 5, 6의 경로들을 모두 끊으면 완전히 끊을 수 있게 된다.
즉 20 곱하기 4의 비용이 누적 총 비용에 합산된다.
여기서 총 4개의 연결이 있다는 사실을 어떻게 알 수 있을까? 여기서 바로 유니온 파인드의 위력이 나오는 것이다.
1과 2를 연결할 때 root값만 바꿔주는 게 아니라 해당 그룹 내 몇개의 정점이 있는지 기록해두면 된다. 물론 루트 정점에 기록해야 한다. 1-2 그룹의 루트가 1이면 cnt[1]=2 이고, 마찬가지로 3-6 그룹의 루트가 3이면 cnt[3]=2 였을 것이다.
두 그룹 간의 연결이 발생하므로 cnt[1] * cnt[3] 개의 케이스가 만들어지게 된다. 그리고 비용 20을 거기에 곱해주면 된다.
그리고 전체 경로의 합에서 연결이 완료된 경로의 가중치를 하나씩 빼면 이번에 추가되는 케이스들의 Cost(u,v)가 몇인지 알 수 있다.
여기서 정말 놀라운 점은 g[u].push_back(v) 등과 같은 그래프 구성을 할 필요가 전혀 없다는 점이다 ㄷㄷ
단순히 어떤 그룹 내에 해당 정점이 포함되는지, 몇 개가 있는지 등만으로 문제가 해결된다는 점이다. (물론 그렇게 풀이가 가능하도록 문제에서 제한을 잘 걸었지만, 그래도 놀랍다.)
<알고리즘>
최소 스패닝 트리, 유니온 파인드
<코드>
자바스크립트
const fs = require('fs')
const input = fs.readFileSync('/dev/stdin').toString().trim().split('\n')
let ans = 0
const mod = 1e9
const [n, m] = input[0].split(' ').map(e => Number(e))
const root = []
const cnt = []
for (let i = 0; i <= n; ++i) {
root.push(i)
cnt.push(1)
}
let t = 0
const edges = []
for (let i = 0; i < m; ++i) {
const [u, v, w] = input[i + 1].split(' ').map(e => Number(e))
edges.push([w, [u, v]])
t += w
}
edges.sort((a, b) => b[0] - a[0])
const find = a => {
if (root[a] === a) return a
return root[a] = find(root[a])
}
const union = (a, b) => {
root[b] = a
cnt[a] += cnt[b]
}
for (const edge of edges) {
let [w, [u, v]] = edge
u = find(u)
v = find(v)
if (u === v) {
t -= w
continue
}
ans += ((t * cnt[u]) % mod * cnt[v]) % mod
union(u, v)
t -= w
}
console.log(ans % mod)
파이썬
from re import A
import sys
si = sys.stdin.readline
n, m = [int(e) for e in si().split()]
t, ans, mod, edges=0, 0, int(1e9), []
root = [i for i in range(n+1)]
edges = [1 for _ in range(n+1)]
def find(a):
if root[a] == a:return a
root[a] = find(root[a])
return root[a]
def union(a, b):
root[a] = b
cnt[b] += cnt[a]
for _ in range(m):
u, v, w = [int(e) for e in si().split()]
edges.append([w, [u, v]])
t += w
edges = sorted(edges, reverse=True)
for edge in edges:
w, [u, v] = edge
u, v = find(u), find(v)
if u != v:
ans += cnt[u]*cnt[v]*t % mod
union(u, v)
t -= w
print(ans % mod)