[BOJ] 1504 특정한 최단경로

다익스트라


해당 문제는 2개의 특정한 지점을 꼭 지나야 하는 최단경로 문제이다.

여기서 입력의 개수가 노드가 800개, 간선이 20만개가 주어진다. 여러 최단경로를 구해야 하므로 처음으로 플로이드 알고리즘을 떠올릴 수 있다. 플로이드는 O(V^3)이므로 800800800 = 약 5억 정도 나오기 때문에 시간 초과가 발생할 확률이 높다.

따라서 다익스트라를 여러번 사용하는 것으로 개선한다.

  • 1 → v1 → v2 → N
  • 1 → v2 → v1 → N

총 3번의 다익스트라 알고리즘의 구현이 필요하다. O(ElogV) * 3 정도면 주어진 시간 내에 연산이 가능하다. 1 → all, v1 → all, v2 → all 을 계산하여 최단 경로를 구하면 된다.

문제에서 무방향 그래프라고 했기 때문에 입력에서 양방향 인접리스트로 초기화 한다. 같은 두 정점사이에 여러 간선이 존재할 수 있으므로 [dict()]형식으로 최소값을 갖는 하나의 간선만 남기도록 구현했다.

코드는 아래와 같다.

import sys
import heapq

INF = int(1e9)

def dijkstra(start, distance):
    q = []
    heapq.heappush(q, (0, start))
    distance[start] = 0
    while q:
        dist, now = heapq.heappop(q)
        if distance[now] < dist:
            continue
        for des, weight in adj_lst[now].items():
            cost = dist + weight
            if cost < distance[des]:
                distance[des] = cost
                heapq.heappush(q, (cost, des))

def solution(n, v1, v2):
    distance = [INF] * (n + 1)
    dijkstra(1, distance)
    one_to_v1 = distance[v1]
    one_to_v2 = distance[v2]

    distance = [INF] * (n + 1)
    dijkstra(v1, distance)
    v1_to_v2 = distance[v2]
    v1_to_n = distance[n]

    distance = [INF] * (n + 1)
    dijkstra(v2, distance)
    v2_to_v1 = distance[v1]
    v2_to_n = distance[n]

    # 사실 무방향 그래프 이므로 v1_to_v2 와 v2_to_v1은 동일하다.
    first_path = one_to_v1 + v1_to_v2 + v2_to_n
    second_path = one_to_v2 + v2_to_v1 + v1_to_n

    answer = min(first_path, second_path)
    if answer >= INF:
        answer = -1

    return answer

if __name__ == "__main__":
    N, E = map(int, sys.stdin.readline().rstrip().split())
    adj_lst = [dict() for _ in range(N + 1)]
    for _ in range(E):
        a, b, c = map(int, sys.stdin.readline().rstrip().split())
        try:
            adj_lst[a][b] = min(adj_lst[a][b], c)
        except:
            adj_lst[a][b] = c
        try:
            adj_lst[b][a] = min(adj_lst[b][a], c)
        except:
            adj_lst[b][a] = c
    v1, v2 = map(int, sys.stdin.readline().rstrip().split())
    print(solution(N, v1, v2))