ps

[백준] 7812 - 중앙 트리

kariskan 2023. 6. 27. 18:39

https://www.acmicpc.net/problem/7812

 

7812번: 중앙 트리

입력은 여러 개의 테스트 케이스로 이루어져 있다. 각 테스트 케이스의 첫 줄에는 트리의 정점의 수 n이 주어진다. (1 ≤ n ≤ 10,000) 각 정점은 0번부터 n-1번까지 번호가 붙여져 있다. 다음 n-1개 줄

www.acmicpc.net

 

모든 경우를 탐색하는 방법으로는 테스트 케이스 당 O(N^2)이 소요돼 풀 수 없다.

따라서 각 정점을 한 번씩만 탐색하는 방식으로 문제를 해결해야 한다.

 

예제 1을 통해 알아보자.

먼저 중앙 정점을 A에서 시작한다고 생각해 보자.

처음 dfs 한 번을 통해서 A에서 모든 정점까지의 거리(sum)를 구할 수 있다.

 

그다음 A와 연결된 B를 중앙 정점으로 설정하면 어떻게 될까?

여기서는 간선에 집중해 볼 필요가 있다.

즉, 중앙 정점을 A에서 B로 옮길 때 간선 A-B를 제외하고는 변화가 없다.

 

그렇다면 A-B 간선이 몇 번 쓰이게 되는지 알아보자.

  1. A를 중앙 정점으로 설정했을 때
    • A-B, A-C, A-D, A-E 4번
  2. B를 중앙 정점으로 설정했을 때
    • A-B 1번

 

A-B 간선이 4번 -> 1번으로 변화한 것을 알 수 있다.

왜 이런 변화가 생겼을까?

그래프를 그려보면 쉽게 알 수 있듯이, B의 자식 노드의 개수와 관련이 있다.

B의 자식 노드를 탐색하려면 무조건 A-B 간선을 지나쳐야 하기 때문이다.

따라서 중앙 정점을 A -> B로 변화할 때 총비용은 다음과 같이 계산할 수 있다.

c를 B를 포함한 자식 노드의 개수라고 할 때,

sum - (n - c) * 2 + c * 2

식을 해석해 보면 이동하려는 정점 서브트리의 노드 개수만큼 감소되고, 이전 정점 서브트리의 노드 개수만큼 증가된다.

 

이 식을 일반화하면 다음과 같다.

 

next_cost = now_cost - next_subtree_child_cnt * weight + (n - next_subtree_child_cnt) * weight;

 

#include <climits>
#include <iostream>
#include <vector>
using namespace std;

vector<vector<pair<int, long long>>> v;
vector<long long> dis, dp, child;
int n;

void go(int node, int pre, long long sum) {
    dis[node] = sum;
    for (auto &i : v[node]) {
        if (i.first != pre) {
            go(i.first, node, sum + i.second);
        }
    }
}

int getChild(int node, int pre) {
    bool flag = false;
    int res = 0;
    for (auto &i : v[node]) {
        if (i.first != pre) {
            flag = true;
            res += getChild(i.first, node);
        }
    }
    return child[node] = res + 1;
}

void go2(int node, int pre, long long now) {
    dp[node] = now;
    for (auto &i : v[node]) {
        if (i.first != pre) {
            long long c = child[i.first];
            go2(i.first, node, now + (n - c) * i.second - c * i.second);
        }
    }
}

int main() {
    while (1) {
        cin >> n;
        if (!n) {
            break;
        }
        dis.clear();
        dis.resize(n);
        v.clear();
        v.resize(n);
        dp.clear();
        dp.resize(n);
        child.clear();
        child.resize(n);

        for (int i = 0; i < n - 1; i++) {
            int a, b, c;
            cin >> a >> b >> c;
            v[a].push_back({b, c});
            v[b].push_back({a, c});
        }
        go(0, -1, 0);
        long long sum = 0;
        for (int i = 0; i < n; i++) {
            sum += dis[i];
        }
        getChild(0, -1);
        go2(0, -1, sum);
        long long ans = LLONG_MAX;
        for (int i = 0; i < n; i++) {
            ans = min(ans, dp[i]);
        }
        cout << ans << '\n';
    }
}