[백준] 7812 - 중앙 트리
https://www.acmicpc.net/problem/7812
7812번: 중앙 트리
입력은 여러 개의 테스트 케이스로 이루어져 있다. 각 테스트 케이스의 첫 줄에는 트리의 정점의 수 n이 주어진다. (1 ≤ n ≤ 10,000) 각 정점은 0번부터 n-1번까지 번호가 붙여져 있다. 다음 n-1개 줄
www.acmicpc.net
모든 경우를 탐색하는 방법으로는 테스트 케이스 당 O(N^2)이 소요돼 풀 수 없다.
따라서 각 정점을 한 번씩만 탐색하는 방식으로 문제를 해결해야 한다.
예제 1을 통해 알아보자.
먼저 중앙 정점을 A에서 시작한다고 생각해 보자.
처음 dfs 한 번을 통해서 A에서 모든 정점까지의 거리(sum)를 구할 수 있다.
그다음 A와 연결된 B를 중앙 정점으로 설정하면 어떻게 될까?
여기서는 간선에 집중해 볼 필요가 있다.
즉, 중앙 정점을 A에서 B로 옮길 때 간선 A-B를 제외하고는 변화가 없다.
그렇다면 A-B 간선이 몇 번 쓰이게 되는지 알아보자.
- A를 중앙 정점으로 설정했을 때
- A-B, A-C, A-D, A-E 4번
- B를 중앙 정점으로 설정했을 때
- A-B 1번
A-B 간선이 4번 -> 1번으로 변화한 것을 알 수 있다.
왜 이런 변화가 생겼을까?
그래프를 그려보면 쉽게 알 수 있듯이, B의 자식 노드의 개수와 관련이 있다.
B의 자식 노드를 탐색하려면 무조건 A-B 간선을 지나쳐야 하기 때문이다.
따라서 중앙 정점을 A -> B로 변화할 때 총비용은 다음과 같이 계산할 수 있다.
c를 B를 포함한 자식 노드의 개수라고 할 때,
sum - (n - c) * 2 + c * 2
식을 해석해 보면 이동하려는 정점 서브트리의 노드 개수만큼 감소되고, 이전 정점 서브트리의 노드 개수만큼 증가된다.
이 식을 일반화하면 다음과 같다.
next_cost = now_cost - next_subtree_child_cnt * weight + (n - next_subtree_child_cnt) * weight;
#include <climits>
#include <iostream>
#include <vector>
using namespace std;
vector<vector<pair<int, long long>>> v;
vector<long long> dis, dp, child;
int n;
void go(int node, int pre, long long sum) {
dis[node] = sum;
for (auto &i : v[node]) {
if (i.first != pre) {
go(i.first, node, sum + i.second);
}
}
}
int getChild(int node, int pre) {
bool flag = false;
int res = 0;
for (auto &i : v[node]) {
if (i.first != pre) {
flag = true;
res += getChild(i.first, node);
}
}
return child[node] = res + 1;
}
void go2(int node, int pre, long long now) {
dp[node] = now;
for (auto &i : v[node]) {
if (i.first != pre) {
long long c = child[i.first];
go2(i.first, node, now + (n - c) * i.second - c * i.second);
}
}
}
int main() {
while (1) {
cin >> n;
if (!n) {
break;
}
dis.clear();
dis.resize(n);
v.clear();
v.resize(n);
dp.clear();
dp.resize(n);
child.clear();
child.resize(n);
for (int i = 0; i < n - 1; i++) {
int a, b, c;
cin >> a >> b >> c;
v[a].push_back({b, c});
v[b].push_back({a, c});
}
go(0, -1, 0);
long long sum = 0;
for (int i = 0; i < n; i++) {
sum += dis[i];
}
getChild(0, -1);
go2(0, -1, sum);
long long ans = LLONG_MAX;
for (int i = 0; i < n; i++) {
ans = min(ans, dp[i]);
}
cout << ans << '\n';
}
}