알고리즘/Graph 그래프

[백준] 1167번: 트리의 지름 - java

아뵹젼 2023. 2. 18.

트리 지름

처음에는 아무 생각 없이 1~v개의 모든 정점에 대하여 가장 먼 노드를 찾기 위해 DFS를 실행했다.

그러나 이 문제는 V가 최대 100000개로, 모든 노드에 대해 DFS를 실행하였을 때 O(n^2) 로 100000x100000 이라는 시간초과를 낳게 되었다.

따라서 트리의 지름에 대해 더 깊은 이해가 필요하였다.

 

 

(출처) 트리의 지름에 대한 고찰 : https://mygumi.tistory.com/226

 

트리란 무엇인가? 트리란 노드 N개에 간선 N-1개가 싸이클 없이 존재하는 그래프이다.

트리에서는 어떤 두 노드를 선택해도 둘 사이에 경로가 항상 하나만 존재하게 된다.

트리에서 어떤 두 노드를 선택해서 양쪽으로 쫙 당길 때, 가장 길게 늘어나는 경우가 있을 것이다.

이럴 때 트리의 모든 노드들은 이 두 노드를 지름의 끝 점으로 하는 원 안에 들어가게 된다.

 

 

트리의 지름을 구성하는 노드들은 항상 리프 노드라는 것을 알 수 있다.

만약 4번 노드에서 출발하여 다른 리프 노드까지의 거리를 구해보면 트리의 지름이 될 수 없는 것을 확인할 수 있다.

그렇다면 트리의 지름을 구성하는 두 개의 리프 노드는 어떻게 구할 수 있을까?

 

위 트리의 임의 노드에서 출발하여 다른 노드까지의 거리가 가장 먼 노드를 찾아봤을 때 9번 노드 또는 12번 노드가 나오게 된다. 

임의의 노드를 기준으로 가장 먼 노드를 찾았다면 이 노드는 트리의 지름을 구성하는 노드 중 하나가 된다.

 

즉, 트리 내의 어떤 임의의 노드를 시작 노드로 하더라도, 해당 노드에서 가장 먼 정점은 노드 a혹은 노드b가 될 것이다.

이러한 특징을 이용해서 풀 수 있는 문제이다.

 

 

따라서 트리의 지름 원리를 이해한다면,

1. 임의의 노드(x)에서 가장 먼 노드(a)를 찾는다. 

2. 해당 노드(a) 에서 가장 먼 노드(b)를 찾는다.

=> a~b가 트리의 지름이 될 것이다.

 

 

위 명제에 대한 증명은 다음과 같다.
에서 가장 먼 노드를  , 와 의 거리를 트리의 지름이라고 하자.

인 경우 : 에서 가장 먼 가 가 되므로 참이다.

인 경우 : 1번과 동일하다.

인 경우 : 1번과 동일하다.

인 경우 : 1번과 동일하다.

와 를 연결한 선분과, 와 를 연결한 선분이 교차하는 경우 :

  • 1. x에서 가장 먼 노드가 y이려면, y는 max(d3,d4) 이상 이여야 한다.
  • 2. 그러나 해당 트리에서 가장 긴 노드간의 거리는 a~b인 d3+d4이기 때문에 a~y 간의 거리, b~y 간의 거리는 d3+d4 이하 여야 한다.
  • => 이 조건을 모두 성립하기 위해서는 y는 a혹은 b가 되야 한다.

 

6. x와 y를 연결한 선분과, a와 b를 연결한 선분이 교차하지 않는 경우 :

  • x에서 가장 거리가 긴 노드가 y가 되기 위해서는 d2는 max(d3+d4, d3+d5) 이상 이 되어야 한다.
  • 그러나 트리에서 가장 긴 노드간의 거리는 a~b인 d4+d5이므로, d1+d2는 max(d1+d3+d4, d1+d3+d5) 이하여야 한다.
  • => 두 조건을 모두 만족하기 위해서 y는 a 혹은 b가 되어야 성립한다.

 

이로 인해 
1. 임의의 정점로부터 가장 먼 정점
2.로부터 가장 먼 정점

  •  의 거리가 트리의 지름이 된다.

= > 성립함을 알게 되었다!!!!

 

 

즉, 2번의 DFS만으로 문제를 풀 수 있게 되는 것이다.

-> 모든 노드에 대해 DFS를 탐색했을 때보다 O(n^2) -> O(n) 으로 엄청난 시간 단축이 생기게 되었다.

 

 

 

나의풀이

import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.StringTokenizer;

public class Main {

    static ArrayList<Node>[] list;
    static boolean[] visited;
    static int max_cost = 0;
    static int node; // 가장 멀리 떨어져있는 노드

    public static class Node {
        int v;
        int w;

        public Node(int v, int w) {
            this.v = v;
            this.w = w;
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st;
        int n = Integer.parseInt(br.readLine());
        list = new ArrayList[n+1];
        visited = new boolean[n+1];
        for(int i = 1; i < n + 1; i++) {
            list[i] = new ArrayList<Node>();
        }

        for (int i = 0; i < n; i++) {
            st = new StringTokenizer(br.readLine());
            int e = Integer.parseInt(st.nextToken());
            while (true) {
                int v = Integer.parseInt(st.nextToken());
                if (v==-1)  break;
                int w = Integer.parseInt(st.nextToken());
                list[e].add(new Node(v,w));
            }
        }

        visited[1] = true;
        dfs(1,0); // 정점1에서 가장 먼 노드 찾기

        Arrays.fill(visited, false);

        visited[node] = true;
        dfs(node, 0); // 정점1에서 가장 먼 노드에서 가장 먼 노드 찾기

        System.out.println(max_cost);
    }

    public static void dfs(int x, int cost){
        if (cost>max_cost){
            max_cost = cost;
            node = x;
        }

        for(int i=0; i<list[x].size(); i++){
            Node n = list[x].get(i);
            if (!visited[n.v]){
                visited[n.v] = true;
                dfs(n.v, cost+n.w);
            }
        }
    }
}

 

 

댓글