[BOJ] 11049 행렬 곱셈 순서

DP


행렬 [N, M] x [M, K] 에 대한 곱셈 연산의 수는 N x M x K 이다.

여러개의 행렬이 쭉 있을때 뭐 부터 곱해야 행렬 곱셈 연산의 수가 최소가 될지 구하는 문제이다.

아래 사진과 같은 행렬이 주어졌다고 생각해보자.

boj11049_1.jpeg

위 사진이 최소 값이라는 것은 아니고 문제를 푸는 방식에 대한 이해를 해본 것이다. 자세히 들여다 보면 행렬에 대한 모든게 필요한 것이 아니라 row, col 에 대한 숫자 정보가 순서를 해치지 않고 연산이 되면 되는 것이다.

따라서 위 행렬 정보를 리스트 하나로 관리 하면 된다.

  • [ 5, 3, 2, 6, 8, 3 ]

이제 이를 통해서 어떻게 곱 연산 수의 최소값을 찾을 수 있을지 알아보자.

해당 문제를 풀기 위해 구해야 할 것은 0~len(list) 까지의 곱을 어떻게 최소화 할 것인가를 구해야 하므로 이를 f(0, N-1) 로 설정한다. 그렇게 되면 풀어야 할 문제는 아래 사진과 같다.

boj11049_2.jpeg

하위 sub problem에서 어떻게 잘라서 곱했을 때 최소값이 나왔는지를 이미 알고있다고 가정한다면 위 식을 구할 수 있다. 따라서 DFS, 재귀적으로 위 트리를 순회하며 메모이제이션을 활용하면 해결할 수 있다.

위 사진의 점화식에 오류가 있어서 식을 다시 쓰면 아래와 같다 (i 가 아니라 mid임)

  • f(row, col) = min( { mid = row + 1 ~ col - 1} | f(row, mid) + f(mid, col) + (list[row] x list[mid] x list[col])

이를 코드로 구현하면 다음과 같다.

구현


함정이 하나 있었다. 행렬이 1개일 때도 만족해야 한다. 따라서 예외처리를 해주었다.

import sys

def matrix_size_mul(num1, num2, num3):
    return num1 * num2 * num3

# Top-Down DP
def dp(row, col):
    if col - row < 2:
        return 0
    if col - row == 2:
        return matrix_size_mul(numbers[row], numbers[row+1], numbers[row+2])    # row + 2 == col

    mem[row][col] = INF
    for mid in range(row+1, col):
        if not mem[row][mid]:
            mem[row][mid] = dp(row, mid)
        if not mem[mid][col]:
            mem[mid][col] = dp(mid, col)

        mem[row][col] = min(
            mem[row][col],
            mem[row][mid] + mem[mid][col] + matrix_size_mul(numbers[row], numbers[mid], numbers[col])
        )

    return mem[row][col]

if __name__ == "__main__":
    INF = int(1e9)
    N = int(sys.stdin.readline().rstrip())
    # 매트릭스 숫자 받는 부분
    numbers = []
    r, c = map(int, sys.stdin.readline().rstrip().split())
    if N == 1:
        print(r * c)
    else:
        numbers.append(r)
        numbers.append(c)
        for _ in range(N-1):
            r, c = map(int, sys.stdin.readline().rstrip().split())
            numbers.append(c)

        # TOP-DOWN DP
        mem = [[0]*len(numbers) for _ in range(len(numbers))]
        print(dp(0, len(numbers)-1))