[BOJ] 10217 KCM Travel
다익스트라, DP
이 문제는 비용과 거리를 모두 고려하며 구현해야 하는 최단경로 문제다. 약간 01배낭 문제랑 다익스트라를 섞어놓은 느낌을 받을 수 있다.
1차시기
내가 처음 생각한 알고리즘
- 행, 열을 각각 노드, 비용으로 두는 2차원 배열을 생성(배낭 문제처럼)
- 1번 노드부터 연결된 인접 노드의 해당 비용일때의 거리를 저장.
- 해당 거리를 key로 하여 min-heap에 push
- pop하여 현재 노드로부터 다시 이어지는 인접 노드의 비용을 업데이트 하고 min-heap에 push
- heap이 빌 때까지 반복
이렇게 구현했더니 메모리 초과가 났다.
2차시기
나는 여기서 메모리 초과는 heap 때문에 발생했다고 생각하여 heap의 사용을 자제하고 그냥 2차원 배열에서 비용을 1씩 늘리며 해당 값이 INF가 아니라면 해당 노드로 이동해서 인접 노드의 비용을 업데이트 하는 방식으로 구현했다.
이렇게 구현한 코드는 다음과 같고 python3에서는 시간초과 pypy3에서는 통과를 했다.
import sys
def solution(n, cost_limit):
# 행은 노드 번호 열은 비용
dp = [[INF] * (cost_limit+1) for _ in range(n+1)]
dp[1][0] = 0 # 1에서 1로 가는 비용은 0
for cost in range(cost_limit+1):
for node in range(1, n+1):
if dp[node][cost] == INF:
continue
for to_cost, to_dist, to_node in routes[node]:
new_cost = cost + to_cost
new_dist = dp[node][cost] + to_dist
if new_cost <= cost_limit:
dp[to_node][new_cost] = min(dp[to_node][new_cost], new_dist)
return min(dp[n]) if min(dp[n]) != INF else "Poor KCM"
if __name__ == "__main__":
INF = int(1e9)
test_cases = int(sys.stdin.readline().rstrip())
for _ in range(test_cases):
N, M, K = map(int, sys.stdin.readline().rstrip().split())
routes = [[] for _ in range(N + 1)]
for _ in range(K):
U, V, C, D = map(int, sys.stdin.readline().rstrip().split())
routes[U].append([C, D, V])
print(solution(N, M))
3차시기
나는 그래서 python3로 통과한분들이 계신가~ 하고 봤더니 있었다. 그래서 코드를 슥 염탐했는데 내 1차시기 코드와 완전 똑같은데 조건 몇 개가 추가되어 있었다.
나는 heap에 중복 데이터가 너무 많이 쌓여서 오류가 난다고 생각했다.
내가 놓친 부분
- 이미 dp table 에 더 최적의 데이터가 있다면 굳이 heap에 넣을 필요가 없다. 나는 이걸 Heap을 탐색해야 한다고 생각을 했는데 그것이 아니었다. 그냥 heap에 안 넣어버리면 되는 것.
- 그리고 for문을 통해서 뒤 노드들을 싹 업데이트 해주는 것. 나는 이게 시간 초과가 날 것이라고 생각했었다.
이 두 부분에 대해서 조건을 추가해주니 메모리 초과가 해결이 되었고 Python3에서 동작하는 코드가 되었다.
import sys
import heapq
def solution(n, start, cost_limit):
# 행은 노드 번호 열은 비용
dp = [[INF] * (cost_limit+1) for _ in range(n+1)]
dp[1][0] = 0 # 1에서 1로 가는 비용은 0
heap = []
heapq.heappush(heap, (0, 0, start))
answer = INF
while heap:
cur_dist, cur_cost, cur_node = heapq.heappop(heap)
if cur_dist > dp[cur_node][cur_cost]:
continue
# 비용에 구애받지 않고 최단 거리로(heap 에 의해 보장) n에 도달한 경우 stop
if cur_node == n:
answer = cur_dist
break
for to_cost, to_dist, to_node in routes[cur_node]:
new_cost = cur_cost + to_cost
new_dist = cur_dist + to_dist
# 만약 new cost 가 cost limit 을 넘기면 pass
if new_cost > cost_limit:
continue
# heap 에 굳이 넣을 이유가 없다면 pass
if dp[to_node][new_cost] <= new_dist:
continue
# dp table 값을 업데이트 해야함
for i in range(new_cost, cost_limit + 1):
if dp[to_node][i] < new_dist:
break
dp[to_node][i] = new_dist
heapq.heappush(heap, (new_dist, new_cost, to_node))
return answer if answer != INF else "Poor KCM"
if __name__ == "__main__":
INF = int(1e9)
test_cases = int(sys.stdin.readline().rstrip())
for _ in range(test_cases):
N, M, K = map(int, sys.stdin.readline().rstrip().split())
routes = [[] for _ in range(N + 1)]
for _ in range(K):
U, V, C, D = map(int, sys.stdin.readline().rstrip().split())
routes[U].append([C, D, V])
print(solution(N, 1, M))