[BOJ] 10217 KCM Travel

다익스트라, DP


이 문제는 비용과 거리를 모두 고려하며 구현해야 하는 최단경로 문제다. 약간 01배낭 문제랑 다익스트라를 섞어놓은 느낌을 받을 수 있다.

1차시기

내가 처음 생각한 알고리즘

  1. 행, 열을 각각 노드, 비용으로 두는 2차원 배열을 생성(배낭 문제처럼)
  2. 1번 노드부터 연결된 인접 노드의 해당 비용일때의 거리를 저장.
  3. 해당 거리를 key로 하여 min-heap에 push
  4. pop하여 현재 노드로부터 다시 이어지는 인접 노드의 비용을 업데이트 하고 min-heap에 push
    1. heap이 빌 때까지 반복

이렇게 구현했더니 메모리 초과가 났다.

2차시기

나는 여기서 메모리 초과는 heap 때문에 발생했다고 생각하여 heap의 사용을 자제하고 그냥 2차원 배열에서 비용을 1씩 늘리며 해당 값이 INF가 아니라면 해당 노드로 이동해서 인접 노드의 비용을 업데이트 하는 방식으로 구현했다.

이렇게 구현한 코드는 다음과 같고 python3에서는 시간초과 pypy3에서는 통과를 했다.

import sys

def solution(n, cost_limit):
    # 행은 노드 번호 열은 비용
    dp = [[INF] * (cost_limit+1) for _ in range(n+1)]
    dp[1][0] = 0    # 1에서 1로 가는 비용은 0

    for cost in range(cost_limit+1):
        for node in range(1, n+1):
            if dp[node][cost] == INF:
                continue
            for to_cost, to_dist, to_node in routes[node]:
                new_cost = cost + to_cost
                new_dist = dp[node][cost] + to_dist
                if new_cost <= cost_limit:
                    dp[to_node][new_cost] = min(dp[to_node][new_cost], new_dist)
    return min(dp[n]) if min(dp[n]) != INF else "Poor KCM"

if __name__ == "__main__":
    INF = int(1e9)
    test_cases = int(sys.stdin.readline().rstrip())
    for _ in range(test_cases):
        N, M, K = map(int, sys.stdin.readline().rstrip().split())
        routes = [[] for _ in range(N + 1)]
        for _ in range(K):
            U, V, C, D = map(int, sys.stdin.readline().rstrip().split())
            routes[U].append([C, D, V])
        print(solution(N, M))

3차시기

나는 그래서 python3로 통과한분들이 계신가~ 하고 봤더니 있었다. 그래서 코드를 슥 염탐했는데 내 1차시기 코드와 완전 똑같은데 조건 몇 개가 추가되어 있었다.

나는 heap에 중복 데이터가 너무 많이 쌓여서 오류가 난다고 생각했다.

내가 놓친 부분

  • 이미 dp table 에 더 최적의 데이터가 있다면 굳이 heap에 넣을 필요가 없다. 나는 이걸 Heap을 탐색해야 한다고 생각을 했는데 그것이 아니었다. 그냥 heap에 안 넣어버리면 되는 것.
  • 그리고 for문을 통해서 뒤 노드들을 싹 업데이트 해주는 것. 나는 이게 시간 초과가 날 것이라고 생각했었다.

이 두 부분에 대해서 조건을 추가해주니 메모리 초과가 해결이 되었고 Python3에서 동작하는 코드가 되었다.

import sys
import heapq

def solution(n, start, cost_limit):
    # 행은 노드 번호 열은 비용
    dp = [[INF] * (cost_limit+1) for _ in range(n+1)]
    dp[1][0] = 0    # 1에서 1로 가는 비용은 0
    heap = []
    heapq.heappush(heap, (0, 0, start))
    answer = INF
    while heap:
        cur_dist, cur_cost, cur_node = heapq.heappop(heap)
        if cur_dist > dp[cur_node][cur_cost]:
            continue
        # 비용에 구애받지 않고 최단 거리로(heap 에 의해 보장) n에 도달한 경우 stop
        if cur_node == n:
            answer = cur_dist
            break
        for to_cost, to_dist, to_node in routes[cur_node]:
            new_cost = cur_cost + to_cost
            new_dist = cur_dist + to_dist
            # 만약 new cost 가 cost limit 을 넘기면 pass
            if new_cost > cost_limit:
                continue
            # heap 에 굳이 넣을 이유가 없다면 pass
            if dp[to_node][new_cost] <= new_dist:
                continue
            # dp table 값을 업데이트 해야함
            for i in range(new_cost, cost_limit + 1):
                if dp[to_node][i] < new_dist:
                    break
                dp[to_node][i] = new_dist
            heapq.heappush(heap, (new_dist, new_cost, to_node))
    return answer if answer != INF else "Poor KCM"

if __name__ == "__main__":
    INF = int(1e9)
    test_cases = int(sys.stdin.readline().rstrip())
    for _ in range(test_cases):
        N, M, K = map(int, sys.stdin.readline().rstrip().split())
        routes = [[] for _ in range(N + 1)]
        for _ in range(K):
            U, V, C, D = map(int, sys.stdin.readline().rstrip().split())
            routes[U].append([C, D, V])
        print(solution(N, 1, M))