Skip to content

Commit 3e55d21

Browse files
committed
feat: Strassen's matrix multiplication algorithm added
1 parent 788d95b commit 3e55d21

1 file changed

Lines changed: 132 additions & 0 deletions

File tree

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import List
2+
3+
# Matrix type used by the annotations in this module: a list of rows,
# each row being a list of ints.
Matrix = List[List[int]]
4+
5+
def add(A: Matrix, B: Matrix) -> Matrix:
    """Return the element-wise sum of two equally-shaped matrices.

    Rows and elements are paired with ``zip`` rather than indexed with
    ``len(A)`` on both axes, so rectangular operands are summed in full
    instead of being cut down to a ``len(A)`` x ``len(A)`` corner.

    Args:
        A: Left operand, given as a list of rows.
        B: Right operand; must have the same shape as ``A``.

    Returns:
        A newly built matrix; neither operand is modified.

    Raises:
        ValueError: If the operands differ in row count or in the length
            of any pair of corresponding rows.
    """
    if len(A) != len(B):
        raise ValueError("Matrices must have the same shape")
    total = []
    for row_a, row_b in zip(A, B):
        # zip() would silently truncate to the shorter row, so check first.
        if len(row_a) != len(row_b):
            raise ValueError("Matrices must have the same shape")
        total.append([x + y for x, y in zip(row_a, row_b)])
    return total
8+
9+
def sub(A: Matrix, B: Matrix) -> Matrix:
    """Return the element-wise difference ``A - B`` of two matrices.

    Rows and elements are paired with ``zip`` rather than indexed with
    ``len(A)`` on both axes, so rectangular operands are processed in full
    instead of being cut down to a ``len(A)`` x ``len(A)`` corner.

    Args:
        A: Minuend, given as a list of rows.
        B: Subtrahend; must have the same shape as ``A``.

    Returns:
        A newly built matrix; neither operand is modified.

    Raises:
        ValueError: If the operands differ in row count or in the length
            of any pair of corresponding rows.
    """
    if len(A) != len(B):
        raise ValueError("Matrices must have the same shape")
    difference = []
    for row_a, row_b in zip(A, B):
        # zip() would silently truncate to the shorter row, so check first.
        if len(row_a) != len(row_b):
            raise ValueError("Matrices must have the same shape")
        difference.append([x - y for x, y in zip(row_a, row_b)])
    return difference
12+
13+
def naive_mul(A: Matrix, B: Matrix) -> Matrix:
    """Multiply two matrices with the schoolbook triple loop.

    Works for any compatible shapes: an (r x m) matrix times an (m x c)
    matrix gives an (r x c) result. Square inputs produce the same result
    as before; the dimensions are simply no longer all taken from
    ``len(A)``.

    Args:
        A: Left operand, given as a list of rows.
        B: Right operand; its row count must equal the length of every
            row of ``A``.

    Returns:
        A newly built product matrix; neither operand is modified.

    Raises:
        ValueError: If a row of ``A`` does not have ``len(B)`` entries.
    """
    inner = len(B)
    cols = len(B[0]) if B else 0
    product = []
    for row_a in A:
        if len(row_a) != inner:
            raise ValueError("Inner dimensions must match for multiplication")
        out_row = [0] * cols
        # i-k-j order: each A[i][k] is read once and then swept across the
        # whole of row k of B, accumulating into the current output row.
        for a_ik, row_b in zip(row_a, B):
            for j in range(cols):
                out_row[j] += a_ik * row_b[j]
        product.append(out_row)
    return product
25+
26+
def next_power_of_two(n: int) -> int:
    """Return the smallest power of two that is greater than or equal to ``n``.

    Any ``n`` of 1 or less yields 1.
    """
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the number of doublings needed to reach or
    # pass n when starting from 1.
    return 1 << (n - 1).bit_length()
31+
32+
def pad_matrix(A: Matrix, size: int) -> Matrix:
    """Copy ``A`` into the top-left corner of a ``size`` x ``size`` zero matrix.

    The number of columns copied from every row is the length of the first
    row of ``A``. ``A`` itself is left untouched.
    """
    canvas = [[0] * size for _ in range(size)]
    width = len(A[0]) if A else 0
    for r, source_row in enumerate(A):
        for c in range(width):
            canvas[r][c] = source_row[c]
    return canvas
39+
40+
def unpad_matrix(A: Matrix, rows: int, cols: int) -> Matrix:
    """Return the top-left ``rows`` x ``cols`` corner of ``A`` as new row lists."""
    trimmed = []
    for source_row in A[:rows]:
        trimmed.append(source_row[:cols])
    return trimmed
42+
43+
def split(A: Matrix) -> tuple:
    """Cut ``A`` into four quadrants around the midpoint ``len(A) // 2``.

    Both the row and the column ranges are derived from ``len(A)``.

    Returns:
        The tuple ``(top_left, top_right, bottom_left, bottom_right)``,
        each quadrant being a freshly built list of rows.
    """
    size = len(A)
    half = size // 2
    low, high = range(half), range(half, size)

    def quadrant(row_ids, col_ids):
        # Copy the cells addressed by the given row/column index ranges.
        return [[A[r][c] for c in col_ids] for r in row_ids]

    return (
        quadrant(low, low),
        quadrant(low, high),
        quadrant(high, low),
        quadrant(high, high),
    )
51+
52+
def join(C11: Matrix, C12: Matrix, C21: Matrix, C22: Matrix) -> Matrix:
    """Assemble four quadrants into one matrix.

    ``C11`` is placed top-left, ``C12`` top-right, ``C21`` bottom-left and
    ``C22`` bottom-right. The side length of every quadrant is taken to be
    ``len(C11)``, so the result has ``2 * len(C11)`` rows and columns.
    """
    half = len(C11)
    span = range(half)

    def stitch(left, right):
        # Lay the two quadrants side by side, one output row per input row.
        return [
            [left[r][c] for c in span] + [right[r][c] for c in span]
            for r in span
        ]

    return stitch(C11, C12) + stitch(C21, C22)
63+
64+
def strassen(A: Matrix, B: Matrix, threshold: int = 64) -> Matrix:
    """Multiply square matrices A and B using Strassen's algorithm.

    Inputs whose size is not a power of two are zero-padded up to the next
    power of two before the recursion starts, and the padding is stripped
    from the result again.

    Args:
        A: Left operand, an n x n matrix given as a list of rows.
        B: Right operand, also n x n.
        threshold: Sub-matrices of this size or smaller are multiplied
            with the naive algorithm instead of recursing further
            (tweakable).

    Returns:
        The n x n product as a new matrix; two empty matrices give ``[]``.

    Raises:
        ValueError: If A and B are not square matrices of the same size.
    """
    n_orig = len(A)
    # Validate with a real exception (asserts vanish under ``python -O``)
    # and look at every row, not just the first. Checking via len() of the
    # rows also keeps empty input from being indexed.
    same_square_shape = (
        len(B) == n_orig
        and all(len(row) == n_orig for row in A)
        and all(len(row) == n_orig for row in B)
    )
    if not same_square_shape:
        raise ValueError("Only square matrices supported in this implementation")
    if n_orig == 0:
        return []

    m = next_power_of_two(n_orig)
    if m != n_orig:
        A_pad = pad_matrix(A, m)
        B_pad = pad_matrix(B, m)
    else:
        # Already a power of two: the operands can be used as they are.
        A_pad, B_pad = A, B

    C_pad = _strassen_recursive(A_pad, B_pad, threshold)

    # Drop the rows/columns that only exist because of the zero padding.
    return unpad_matrix(C_pad, n_orig, n_orig)
86+
87+
def _strassen_recursive(A: Matrix, B: Matrix, threshold: int) -> Matrix:
    """Recursive step of Strassen's algorithm.

    Matrices of size ``threshold`` or smaller are handed to ``naive_mul``;
    a 1 x 1 pair is multiplied directly. Anything larger is split into
    quadrants, combined into seven recursive products, and the four result
    quadrants are joined back together.
    """
    size = len(A)
    if size <= threshold:
        return naive_mul(A, B)
    if size == 1:
        return [[A[0][0] * B[0][0]]]

    a11, a12, a21, a22 = split(A)
    b11, b12, b21, b22 = split(B)

    def product(left: Matrix, right: Matrix) -> Matrix:
        # Recurse with the same cut-over threshold.
        return _strassen_recursive(left, right, threshold)

    # The seven Strassen products.
    p1 = product(add(a11, a22), add(b11, b22))
    p2 = product(add(a21, a22), b11)
    p3 = product(a11, sub(b12, b22))
    p4 = product(a22, sub(b21, b11))
    p5 = product(add(a11, a12), b22)
    p6 = product(sub(a21, a11), add(b11, b12))
    p7 = product(sub(a12, a22), add(b21, b22))

    # Recombine the products into the four quadrants of the result.
    top_left = add(sub(add(p1, p4), p5), p7)
    top_right = add(p3, p5)
    bottom_left = add(p2, p4)
    bottom_right = add(sub(add(p1, p3), p2), p6)

    return join(top_left, top_right, bottom_left, bottom_right)
111+
112+
if __name__ == "__main__":
    # Small 3x3 demonstration that is cross-checked against naive_mul.
    left = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    right = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]

    product = strassen(left, right, threshold=1)
    print("A * B =")
    for product_row in product:
        print(product_row)

    # verify against naive
    reference = naive_mul(left, right)
    assert product == reference, "Strassen result differs from naive multiplication!"
    print("Verified: result matches naive multiplication.")

0 commit comments

Comments
 (0)