|
| 1 | +from typing import List |
| 2 | + |
| 3 | +Matrix = List[List[int]] |
| 4 | + |
| 5 | +def add(A: Matrix, B: Matrix) -> Matrix: |
| 6 | + n = len(A) |
| 7 | + return [[A[i][j] + B[i][j] for j in range(n)] for i in range(n)] |
| 8 | + |
| 9 | +def sub(A: Matrix, B: Matrix) -> Matrix: |
| 10 | + n = len(A) |
| 11 | + return [[A[i][j] - B[i][j] for j in range(n)] for i in range(n)] |
| 12 | + |
| 13 | +def naive_mul(A: Matrix, B: Matrix) -> Matrix: |
| 14 | + n = len(A) |
| 15 | + C = [[0]*n for _ in range(n)] |
| 16 | + for i in range(n): |
| 17 | + ai = A[i] |
| 18 | + ci = C[i] |
| 19 | + for k in range(n): |
| 20 | + a_ik = ai[k] |
| 21 | + bk = B[k] |
| 22 | + for j in range(n): |
| 23 | + ci[j] += a_ik * bk[j] |
| 24 | + return C |
| 25 | + |
| 26 | +def next_power_of_two(n: int) -> int: |
| 27 | + p = 1 |
| 28 | + while p < n: |
| 29 | + p <<= 1 |
| 30 | + return p |
| 31 | + |
| 32 | +def pad_matrix(A: Matrix, size: int) -> Matrix: |
| 33 | + n = len(A) |
| 34 | + padded = [[0]*size for _ in range(size)] |
| 35 | + for i in range(n): |
| 36 | + for j in range(len(A[0])): |
| 37 | + padded[i][j] = A[i][j] |
| 38 | + return padded |
| 39 | + |
| 40 | +def unpad_matrix(A: Matrix, rows: int, cols: int) -> Matrix: |
| 41 | + return [row[:cols] for row in A[:rows]] |
| 42 | + |
| 43 | +def split(A: Matrix) -> tuple: |
| 44 | + n = len(A) |
| 45 | + mid = n // 2 |
| 46 | + A11 = [[A[i][j] for j in range(mid)] for i in range(mid)] |
| 47 | + A12 = [[A[i][j] for j in range(mid, n)] for i in range(mid)] |
| 48 | + A21 = [[A[i][j] for j in range(mid)] for i in range(mid, n)] |
| 49 | + A22 = [[A[i][j] for j in range(mid, n)] for i in range(mid, n)] |
| 50 | + return A11, A12, A21, A22 |
| 51 | + |
| 52 | +def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix: |
| 53 | + n2 = len(C11) |
| 54 | + n = n2 * 2 |
| 55 | + C = [[0]*n for _ in range(n)] |
| 56 | + for i in range(n2): |
| 57 | + for j in range(n2): |
| 58 | + C[i][j] = C11[i][j] |
| 59 | + C[i][j + n2] = C12[i][j] |
| 60 | + C[i + n2][j] = C21[i][j] |
| 61 | + C[i + n2][j + n2] = C22[i][j] |
| 62 | + return C |
| 63 | + |
| 64 | +def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix: |
| 65 | + """ |
| 66 | + Multiply square matrices A and B using Strassen algorithm. |
| 67 | + threshold: below this size, uses naive multiplication (tweakable). |
| 68 | + """ |
| 69 | + assert len(A) == len(A[0]) == len(B) == len(B[0]), "Only square matrices supported in this implementation" |
| 70 | + |
| 71 | + n_orig = len(A) |
| 72 | + if n_orig == 0: |
| 73 | + return [] |
| 74 | + |
| 75 | + m = next_power_of_two(n_orig) |
| 76 | + if m != n_orig: |
| 77 | + A_pad = pad_matrix(A, m) |
| 78 | + B_pad = pad_matrix(B, m) |
| 79 | + else: |
| 80 | + A_pad, B_pad = A, B |
| 81 | + |
| 82 | + C_pad = _strassen_recursive(A_pad, B_pad, threshold) |
| 83 | + |
| 84 | + C = unpad_matrix(C_pad, n_orig, n_orig) |
| 85 | + return C |
| 86 | + |
| 87 | +def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix: |
| 88 | + n = len(A) |
| 89 | + if n <= threshold: |
| 90 | + return naive_mul(A, B) |
| 91 | + if n == 1: |
| 92 | + return [[A[0][0] * B[0][0]]] |
| 93 | + |
| 94 | + A11, A12, A21, A22 = split(A) |
| 95 | + B11, B12, B21, B22 = split(B) |
| 96 | + |
| 97 | + M1 = _strassen_recursive(add(A11, A22), add(B11, B22), threshold) |
| 98 | + M2 = _strassen_recursive(add(A21, A22), B11, threshold) |
| 99 | + M3 = _strassen_recursive(A11, sub(B12, B22), threshold) |
| 100 | + M4 = _strassen_recursive(A22, sub(B21, B11), threshold) |
| 101 | + M5 = _strassen_recursive(add(A11, A12), B22, threshold) |
| 102 | + M6 = _strassen_recursive(sub(A21, A11), add(B11, B12), threshold) |
| 103 | + M7 = _strassen_recursive(sub(A12, A22), add(B21, B22), threshold) |
| 104 | + |
| 105 | + C11 = add(sub(add(M1, M4), M5), M7) |
| 106 | + C12 = add(M3, M5) |
| 107 | + C21 = add(M2, M4) |
| 108 | + C22 = add(sub(add(M1, M3), M2), M6) |
| 109 | + |
| 110 | + return join(C11, C12, C21, C22) |
| 111 | + |
| 112 | +if __name__ == "__main__": |
| 113 | + A = [ |
| 114 | + [1, 2, 3], |
| 115 | + [4, 5, 6], |
| 116 | + [7, 8, 9] |
| 117 | + ] |
| 118 | + B = [ |
| 119 | + [9, 8, 7], |
| 120 | + [6, 5, 4], |
| 121 | + [3, 2, 1] |
| 122 | + ] |
| 123 | + |
| 124 | + C = strassen(A, B, threshold=1) |
| 125 | + print("A * B =") |
| 126 | + for row in C: |
| 127 | + print(row) |
| 128 | + |
| 129 | + # verify against naive |
| 130 | + expected = naive_mul(A, B) |
| 131 | + assert C == expected, "Strassen result differs from naive multiplication!" |
| 132 | + print("Verified: result matches naive multiplication.") |
0 commit comments