diff --git a/divide_and_conquer/strassen_matrix_multiplication.py b/divide_and_conquer/strassen_matrix_multiplication.py index f529a255d2ef..a69589742700 100644 --- a/divide_and_conquer/strassen_matrix_multiplication.py +++ b/divide_and_conquer/strassen_matrix_multiplication.py @@ -73,8 +73,28 @@ def print_matrix(matrix: list) -> None: def actual_strassen(matrix_a: list, matrix_b: list) -> list: """ - Recursive function to calculate the product of two matrices, using the Strassen - Algorithm. It only supports square matrices of any size that is a power of 2. + Recursive function to calculate the product of two matrices using the Strassen + Algorithm. + + This is the core recursive implementation that only supports square matrices + of size that is a power of 2 (e.g., 2x2, 4x4, 8x8, 16x16, etc.). + + The algorithm works by: + 1. Base case: For 2x2 matrices, use standard multiplication + 2. Recursive case: + - Split both matrices into 4 quadrants + - Compute 7 products using the formulas above + - Combine the 7 products to get the final result + + Args: + matrix_a: Square matrix with dimensions as power of 2 + matrix_b: Square matrix with dimensions as power of 2 + + Returns: + Product matrix + + Raises: + Exception: If matrices are not square or dimensions are not power of 2 """ if matrix_dimensions(matrix_a) == (2, 2): return default_matrix_multiplication(matrix_a, matrix_b) @@ -106,6 +126,42 @@ def actual_strassen(matrix_a: list, matrix_b: list) -> list: def strassen(matrix1: list, matrix2: list) -> list: """ + Multiplies two matrices using the Strassen algorithm for improved time complexity. + + The Strassen algorithm reduces the complexity of matrix multiplication from + O(n³) to O(n^2.807) by reducing the number of recursive matrix multiplications + from 8 to 7. While the asymptotic complexity is better, the actual performance + improvement is typically seen only for large matrices due to higher constant + factors and additional overhead. + + Time Complexity: O(n^2.807) - Strassen vs O(n³) for standard multiplication + Space Complexity: O(n²) for storing the result matrix + + The algorithm works by recursively dividing matrices into 2x2 submatrices and + computing 7 products (P1-P7) instead of 8: + P1 = A * (F - H) + P2 = (A + B) * H + P3 = (C + D) * E + P4 = D * (G - E) + P5 = (A + D) * (E + H) + P6 = (B - D) * (G + H) + P7 = (A - C) * (E + F) + + Then combines these products to get the final result matrix. + + Note: This implementation requires input matrices to have dimensions that are + powers of 2. The function automatically pads matrices with zeros if needed. + + Args: + matrix1: First matrix (m x n) + matrix2: Second matrix (n x p) + + Returns: + Result matrix (m x p) + + Raises: + Exception: If matrix dimensions are incompatible for multiplication + >>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]]) [[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]] >>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]])