Multiply Chain Matrix Algorithm
MCM algorithm이란?
두 개 이상의 행렬을 곱할 때, 곱하기 연산을 최소로 할 수 있게 적절히 계산 순서를 바꿔주는 문제이다.
행렬의 곱셈은 다음과 같이 결합법칙이 성립된다.
A x B x C = (A x B) x C = A x (B x C)
하지만 행렬을 곱하는 순서에 따라 곱하는 횟수가 달라지게 된다.
예를 들어,
A = 20 x 1, B = 1 x 30, C = 30 x 10, D = 10 x 10 의 크기를 가진 행렬들이 있다고 가정하자
A x B x C x D의 값을 구하려면 적절히 괄호를 쳐서 아래와 같은 경우의 수를 만들 수 있다.
- ((A x B) x C ) x D) = (20 x 1 x 30) + (20 x 30 x 10) + (20 x 10 x 10) = 8,600
- A x (B x (C x D)) = (30 x 10 x 10) + (1 x 30 x 10) + (20 x 1 x 10) = 3,500
- (A x B) x (C x D) = (20 x 1 x 30) + (30 x 10 x 10) + (20 x 30 x 10) = 9,600
- (A x ((B x C) x D) = (1 x 30 x 10) + (1 x 10 x 10) + (20 x 1 x 10) = 600
마지막 경우 처럼 곱셈을 실행하면 600번의 연산으로 행렬 곱 연산을 끝 마칠 수 있는 것이다.
즉, 연쇄행렬 최소곱셈 알고리즘(MCM)은 행렬곱셈에서 곱하는 순서에 따라 곱셈의 횟수가 달라지는데 이러한 성질을 이용하여 최소로 곱하는 횟수를 구하는 것이다.
MCM algorithm 설계
위의 예시에서 사용하였던 A, B, C, D 행렬을 다시 가져오자. (A = 20 x 1, B = 1 x 30, C = 30 x 10, D = 10 x 10 )
A를 1번 행렬, B를 2번 행렬, C를 3번 행렬, D를 4번 행렬이라고 할 때,
- M[i][j]를 i번 행렬에서 j번 행렬까지 곱하는 최소 횟수라고 정의하자.
ex) M[1][3]: 1번 행렬 (A)에서 3번 행렬(C) 까지 곱하는 횟수 == A x B x C 를 연산할 때의 최소 횟수
- A i = di-1 x di 이라고 정의하자
ex) A가 1번째 행렬으므로, d0 = 20, d1 = 1
B가 2번째 행렬이므로, d1 = 1, d2 = 30
C가 3번째 행렬이므로, d2 = 30, d3 = 10
D가 4번째 행렬이므로, d3 = 10, d4 = 10
위의 내용으로 다시 M[i][j]을 정의하면
로 정의할 수 있다.
3개 이상의 행렬을 곱할때, 끊어서 곱하겠다는 이야기이다.
무슨 뜻인지 잘 와닿지 않을 수 있으니 예시를 보자
M[1][3] 일때 i = 1, j = 3이므로 k는 1 <= k <= 2, k = 1, 2가 된다.
k = 1 인 경우, M[1][3] = M[1][1] + M[2][3] + d0d1d3
k = 2 인 경우, M[1][3] = M[1][2] + M[3][3] + d0d2d3
두 가지 경우 중 M[1][3]의 값은 위 두 경우 중 작은 값을 넣어주면 되는 것이다.
한 가지 더욱 쉬운 예시는 M[1][2]이다.
M[1][2] 일 때, i = 1, j = 2, k = 1이 된다.
따라서 k = 1인 경우가 최솟값이 된다.
M[1][2] = M[1][1] + M[2][2] + d0d1d2 = 0 + 0 + 20 * 1 * 30 = 600
즉, 행렬이 N개 일때, M[1][N]의 값을 구하는 것이 최종 목표이고 값을 구하기 위해서는 각 구간의 값을 모두 구해서 최솟값을 찾아야 한다.
M[1][4]의 값을 구하려면
- M[1][1] + M[2][4] + d0 x d1 x d4
- M[1][2] + M[3][4] + d0 x d2 x d4
- M[1][3] + M[4][4] + d0 x d3 x d4 중 최솟값을 찾아야 한다.
따라서 M[1][1]~M[1][4]의 값과 M[2][4]~M[4][4]의 값이 필요하다.
M[i][j]의 값은 대각선을 하나씩 증가시키며 아래와 같이 구할 수 있다.
- (i, i) 구하기
(0,0) | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | 0 | |||
2 | 0 | |||
3 | 0 | |||
4 | 0 |
- (1, 2) ~ (3, 4)
(0,0) | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | 0 | 600 | ||
2 | 0 | 300 | ||
3 | 0 | 3000 | ||
4 | 0 |
- (1, 3) ~ (2, 4)
(0,0) | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | 0 | 600 | 500 | |
2 | 0 | 300 | 400 | |
3 | 0 | 3000 | ||
4 | 0 |
- (1, 4)
(0,0) | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | 0 | 600 | 500 | 600 |
2 | 0 | 300 | 400 | |
3 | 0 | 3000 | ||
4 | 0 |
M[1][4] = 600이며 우리가 찾고자 하는 값이 나왔다.
S[i][j] 배열
에서 k값을 저장하는 배열이다.
M[i][j]에서 k값을 무엇으로 잡았을 때 가장 최소가 되었는가? 에 대한 정보를 저장하는 배열이다.
S[i][j] 배열
(0,0) | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | 0 | 1 | ||
2 | 0 | 2 | ||
3 | 0 | 3 | ||
4 | 0 |
M[i][j] 배열
(0,0) | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | 0 | 600 | ||
2 | 0 | 300 | ||
3 | 0 | 3000 | ||
4 | 0 |
d0 = 20, d1 = 1, d2 = 30, d3 = 10, d4 = 10
M[1][3] 일때 i = 1, j = 3이므로 k는 1 <= k <= 2, k = 1, 2가 된다.
k = 1 인 경우, M[1][3] = M[1][1] + M[2][3] + d0d1d3 = 0 + 300 + 20*1*10 = 500
k = 2 인 경우, M[1][3] = M[1][2] + M[3][3] + d0d2d3 = 600 + 0 + 20*30*10 = 6600
여기서 k = 1
(0,0) | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | 0 | 1 | 1 | |
2 | 0 | 2 | ||
3 | 0 | 3 | ||
4 | 0 |
M[2][4] 일때 i = 2, j = 4이므로 k는 2 <= k <= 3, k = 3, 4가 된다.
k = 2 인 경우, M[2][4] = M[2][2] + M[3][4] + d1d2d4 = 0 + 3000 + 1*30*10 = 3300
k = 3 인 경우, M[1][3] = M[2][3] + M[4][4] + d1d3d4 = 300+ 0 + 1*10*10 = 400
여기서 k = 3,
(1, 4)는 생략..
(0,0) | 1 | 2 | 3 | 4 |
---|---|---|---|---|
1 | 0 | 1 | 1 | 1 |
2 | 0 | 2 | 3 | |
3 | 0 | 3 | ||
4 | 0 |
여기서 S[1][4] = 1 이므로,
M[1][4] = M[1][1] + M[2][4] = d0d1d4 라는 것을 알 수 있고,
S[2][4] = 3 이므로,
M[2][4] = M[2][3] + M[4][4] = d1d3d4 라는 것을 알 수 있다.
따라서 (A x (B x C )x D))의 순서로 곱했음을 시사한다.
구현
M[i][j]의 점화식을 토대로 작성해보자.
#include <cstdio>
#include <algorithm>
#include <limits.h>
using namespace std;
int main() {
int N;
scanf("%d", &N);
int matrics[1001][2];
int d[2002];
int M[1001][1001];
int S[1001][1001];
for (int i = 0; i < N; i++) {
scanf("%d %d", &matrics[i][0], &matrics[i][1]);
}
// Make d[i] array
d[0] = matrics[0][0];
for (int i = 0; i < N; i++) {
d[i+1] = matrics[i][1];
}
// Memoization
// if (i,i) = 0
for (int i = 1; i <= N; i++)
M[i][i] = 0;
// else
for (int r = 2; r <= N; r++) { // r is chain length
for (int i = 1; i <= N-r+1; i++) {
int j = i+r-1;
M[i][j] = INT_MAX;
for (int k = i; k < j; k++) { // k is diverging point
if (M[i][j] > M[i][k] + M[k+1][j] + d[i-1]*d[k]*d[j]) {
M[i][j] = M[i][k] + M[k+1][j] + d[i-1]*d[k]*d[j];
S[i][j] = k;
}
}
}
}
// Print M, S array
printf("\n");
printf("M Array\n");
for (int i = 1; i <= N; i++) {
for (int j = 1; j <= N; j++) {
printf("%d ", M[i][j]);
}
printf("\n");
}
printf("\n");
printf("S Array\n");
for (int i = 1; i <= N; i++) {
for (int j = 1; j <= N; j++) {
printf("%d ", S[i][j]);
}
printf("\n");
}
return 0;
}
'Algorithm > 다이나믹 프로그래밍' 카테고리의 다른 글
Dynamic Programming (0) | 2021.07.06 |
---|---|
[DP] Knapsack problem (0) | 2021.07.06 |
[DP] LCS (Longest Common Subsequence) (0) | 2020.04.18 |