Skip to content

Commit 51a69aa

Browse files
committed
Implement Strassen Algorithm
1 parent 2478603 commit 51a69aa

1 file changed

Lines changed: 182 additions & 0 deletions

File tree

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import java.io.BufferedReader;
2+
import java.io.IOException;
3+
import java.io.InputStreamReader;
4+
import java.util.Arrays;
5+
import java.util.stream.Collectors;
6+
import java.util.stream.IntStream;
7+
8+
/**
9+
* Class providing the Strassen algorithm as a divide-and-conquer
10+
* functionality to multiply two square matrices.
11+
*/
12+
public class StrassenAlgorithm {
13+
14+
/**
15+
* Adds matrix A to B.
16+
*
17+
* @param A The first matrix to add.
18+
* @param B The second matrix to add.
19+
* @return The sum of A and B.
20+
*/
21+
public static int[][] add(int[][] A, int[][] B) {
22+
int n = A.length;
23+
int[][] C = new int[n][n];
24+
for (int row = 0; row < n; row++) {
25+
final int i = row;
26+
C[i] = IntStream.range(0, n).map(j -> A[i][j] + B[i][j]).toArray();
27+
}
28+
return C;
29+
}
30+
31+
/**
32+
* Subtracts matrix B from A.
33+
*
34+
* @param A The matrix to subtract from.
35+
* @param B The matrix to subtract.
36+
* @return The difference between A and B.
37+
*/
38+
public static int[][] sub(int[][] A, int[][] B) {
39+
int n = A.length;
40+
int[][] C = new int[n][n];
41+
for (int row = 0; row < n; row++) {
42+
final int i = row;
43+
C[i] = IntStream.range(0, n).map(j -> A[i][j] - B[i][j]).toArray();
44+
}
45+
return C;
46+
}
47+
48+
/**
49+
* Multiply two matrices A and B using the Strassen algorithm.
50+
*
51+
* @param A The first matrix.
52+
* @param B The second matrix.
53+
* @return The product of A and B.
54+
*/
55+
public static int[][] multiply(int[][] A, int[][] B) {
56+
// Size of the matrices
57+
int n = A.length;
58+
59+
// Return matrix initialization
60+
int[][] R = new int[n][n];
61+
62+
if (n == 1) {
63+
// Stop recursion, base case with matrix sizes 1
64+
R[0][0] = A[0][0] * B[0][0];
65+
} else if (n % 2 != 0) {
66+
// Odd matrix size, expand matrices to be of even size
67+
int[][] Anew = new int[n + 1][n + 1];
68+
int[][] Bnew = new int[n + 1][n + 1];
69+
conquer(A, Anew, 0, 0);
70+
conquer(B, Bnew, 0, 0);
71+
72+
// Multiply matrices of even size
73+
int[][] Cnew = multiply(Anew, Bnew);
74+
75+
// Extract relevant values
76+
for (int i = 0; i < n; i++) {
77+
System.arraycopy(Cnew[i], 0, R[i], 0, n);
78+
}
79+
} else {
80+
// Divide matrix A into 4 quarters
81+
int[][] A11 = divide(A, 0, 0, n / 2);
82+
int[][] A12 = divide(A, 0, n / 2, n / 2);
83+
int[][] A21 = divide(A, n / 2, 0, n / 2);
84+
int[][] A22 = divide(A, n / 2, n / 2, n / 2);
85+
86+
// Divide matrix B into 4 quarters
87+
int[][] B11 = divide(B, 0, 0, n / 2);
88+
int[][] B12 = divide(B, 0, n / 2, n / 2);
89+
int[][] B21 = divide(B, n / 2, 0, n / 2);
90+
int[][] B22 = divide(B, n / 2, n / 2, n / 2);
91+
92+
// Apply all the calculation steps from the algorithm itself
93+
int[][] M1 = multiply(add(A11, A22), add(B11, B22));
94+
int[][] M2 = multiply(add(A21, A22), B11);
95+
int[][] M3 = multiply(A11, sub(B12, B22));
96+
int[][] M4 = multiply(A22, sub(B21, B11));
97+
int[][] M5 = multiply(add(A11, A12), B22);
98+
int[][] M6 = multiply(sub(A21, A11), add(B11, B12));
99+
int[][] M7 = multiply(sub(A12, A22), add(B21, B22));
100+
int[][] C11 = add(sub(add(M1, M4), M5), M7);
101+
int[][] C12 = add(M3, M5);
102+
int[][] C21 = add(M2, M4);
103+
int[][] C22 = add(sub(add(M1, M3), M2), M6);
104+
105+
// Join matrices and return result
106+
conquer(C11, R, 0, 0);
107+
conquer(C12, R, 0, n / 2);
108+
conquer(C21, R, n / 2, 0);
109+
conquer(C22, R, n / 2, n / 2);
110+
}
111+
return R;
112+
}
113+
114+
/**
115+
* Extract a square sub-matrix from matrix A.
116+
*
117+
* @param A The matrix to extract a sub-matrix from.
118+
* @param row The start row index.
119+
* @param col The start column index.
120+
* @param n The size of the sub-matrix to extract.
121+
* @return The extracted sub-matrix of A.
122+
*/
123+
private static int[][] divide(int[][] A, int row, int col, int n) {
124+
int[][] R = new int[n][n];
125+
for (int Rrow = 0, Arow = row; Rrow < n; Rrow++, Arow++) {
126+
System.arraycopy(A[Arow], col, R[Rrow], 0, n);
127+
}
128+
return R;
129+
}
130+
131+
/**
132+
* Insert a matrix S into A.
133+
*
134+
* @param S The matrix to insert.
135+
* @param A The matrix to insert S into.
136+
* @param row The start row index.
137+
* @param col The start column index.
138+
*/
139+
private static void conquer(int[][] S, int[][] A, int row, int col) {
140+
for (int Srow = 0, Arow = row; Srow < S.length; Srow++, Arow++) {
141+
System.arraycopy(S[Srow], 0, A[Arow], col, S.length);
142+
}
143+
}
144+
145+
public static void main(String[] args) throws IOException {
146+
System.out.println("Starting the Strassen Algorithm...");
147+
148+
try (InputStreamReader isr = new InputStreamReader(System.in);
149+
BufferedReader in = new BufferedReader(isr)) {
150+
System.out.print("Enter number of rows/columns of the input matrices: ");
151+
int n = Integer.parseInt(in.readLine());
152+
153+
int[][] A = new int[n][n];
154+
int[][] B = new int[n][n];
155+
156+
System.out.println("Enter first matrix elements:");
157+
158+
for (int i = 0; i < n; i++) {
159+
String[] line = in.readLine().split(" ");
160+
A[i] = Arrays.stream(line).mapToInt(Integer::parseInt).toArray();
161+
}
162+
163+
System.out.println("Enter second matrix elements:");
164+
165+
for (int i = 0; i < n; i++) {
166+
String[] line = in.readLine().split(" ");
167+
B[i] = Arrays.stream(line).mapToInt(Integer::parseInt).toArray();
168+
}
169+
170+
int[][] C = multiply(A, B);
171+
172+
System.out.println("Product of the matrices:");
173+
174+
for (int i = 0; i < n; i++) {
175+
System.out.println(Arrays.stream(C[i])
176+
.mapToObj(e -> e + " ")
177+
.collect(Collectors.joining()));
178+
}
179+
}
180+
}
181+
182+
}

0 commit comments

Comments
 (0)