Python: Divide and Conquer Recursive Matrix Multiplication

Question:

I’m trying to implement divide-and-conquer matrix multiplication (the 8-recursion version, not Strassen). I thought I had it figured out, but it produces weird output with too many nested lists and the wrong values. I suspect the problem is how I’m summing the 8 recursive results, but I’m not sure.
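For reference, the eight-recursion scheme splits each matrix into four quadrants (named a through h, as in the code below). Each "+" here is element-wise addition of two half-size blocks, and the four result blocks then have to be stitched back into a single matrix:

```latex
\begin{pmatrix} a & b \\ c & d \end{pmatrix}
\begin{pmatrix} e & f \\ g & h \end{pmatrix}
=
\begin{pmatrix} ae + bg & af + bh \\ ce + dg & cf + dh \end{pmatrix},
\qquad
T(n) = 8\,T(n/2) + \Theta(n^2) = \Theta(n^3)
```

Eight half-size multiplications plus quadratic work for the additions give the same cubic running time as the naive triple loop; Strassen's variant gets by with seven multiplications.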

def multiMatrix(x,y):
    n = len(x)
    if n == 1:
        return x[0][0] * y[0][0]
    else:
        a = [[col for col in row[:len(row)//2]] for row in x[:len(x)//2]]
        b = [[col for col in row[len(row)//2:]] for row in x[:len(x)//2]]
        c = [[col for col in row[:len(row)//2]] for row in x[len(x)//2:]]
        d = [[col for col in row[len(row)//2:]] for row in x[len(x)//2:]]
        e = [[col for col in row[:len(row)//2]] for row in y[:len(y)//2]]
        f = [[col for col in row[len(row)//2:]] for row in y[:len(y)//2]]
        g = [[col for col in row[:len(row)//2]] for row in y[len(y)//2:]]
        h = [[col for col in row[len(row)//2:]] for row in y[len(y)//2:]]
        ae = multiMatrix(a,e)
        bg = multiMatrix(b,g)
        af = multiMatrix(a,f)
        bh = multiMatrix(b,h)
        ce = multiMatrix(c,e)
        dg = multiMatrix(d,g)
        cf = multiMatrix(c,f)
        dh = multiMatrix(d,h)

        c = [[ae+bg,af+bh],[ce+dg,cf+dh]]

        return c


a = [
    [1,2,3,4],
    [5,6,7,8],
    [9,10,11,12],
    [13,14,15,16]
    ]
b = [
    [1,2,3,4],
    [5,6,7,8],
    [9,10,11,12],
    [13,14,15,16]
    ]

print(multiMatrix(a, b))
Asked By: Solomon Bothwell


Answers:

Your suspicion is correct: your matrices are still lists, so adding them with + just concatenates them into a longer list instead of summing them element by element.
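A quick illustration of the difference, using two small blocks chosen only for this example:

```python
# Two 2x2 blocks stored as nested lists
p = [[1, 2], [3, 4]]
q = [[5, 6], [7, 8]]

# `+` on lists concatenates: the result has 4 rows, not 2
concatenated = p + q
print(concatenated)  # [[1, 2], [3, 4], [5, 6], [7, 8]]

# Element-wise addition has to be spelled out row by row
summed = [[x + y for x, y in zip(row_p, row_q)] for row_p, row_q in zip(p, q)]
print(summed)  # [[6, 8], [10, 12]]
```

In the question's code, each of ae, bg, ... is such a block once n reaches 4, so ae+bg is a stack of rows rather than a sum, and wrapping those stacks in another [[...],[...]] adds the extra nesting.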

Try using something like this in your code to add two matrices element-wise:

def matrix_add(a, b):
    return [[ea+eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]
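Because `matrix_add` pairs up rows and then entries within each row, it is not limited to square blocks. A small check with two 2 x 3 matrices (values chosen arbitrarily for illustration):

```python
def matrix_add(a, b):
    # Pair up corresponding rows, then corresponding entries within each row
    return [[ea + eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]

# Works for any two matrices of the same shape, e.g. 2 x 3
left = [[1, 2, 3], [4, 5, 6]]
right = [[10, 20, 30], [40, 50, 60]]
print(matrix_add(left, right))  # [[11, 22, 33], [44, 55, 66]]
```

Note that `zip` stops at the shorter input, so blocks of mismatched shape are silently truncated rather than rejected.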

To join blocks:

def join_horiz(a, b):
    return [rowa + rowb for rowa, rowb in zip(a,b)]

def join_vert(a, b):
    return a+b
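For example, four 1x1 quadrants (made-up values) can be reassembled into a 2x2 matrix by joining each pair horizontally and then stacking the two halves:

```python
def join_horiz(a, b):
    # Glue corresponding rows side by side
    return [rowa + rowb for rowa, rowb in zip(a, b)]

def join_vert(a, b):
    # Stack b's rows underneath a's rows
    return a + b

# Four 1x1 quadrants reassembled into a 2x2 matrix
c11, c12 = [[1]], [[2]]
c21, c22 = [[3]], [[4]]
full = join_vert(join_horiz(c11, c12), join_horiz(c21, c22))
print(full)  # [[1, 2], [3, 4]]
```

Here `join_vert` relies on exactly the list concatenation that caused the original bug: stacking rows is the one place where `+` on lists is what you want.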

Finally, to make it all work together, I think you have to change your special case for n == 1 so that it returns a 1x1 matrix instead of a scalar:

return [[x[0][0] * y[0][0]]]
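One way to put these pieces together, sketched under the assumption that both inputs are square with the same power-of-two size. It keeps the question's `multiMatrix` name and the helpers above, and uses `//` so the halving also works in Python 3:

```python
def matrix_add(a, b):
    return [[ea + eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]

def join_horiz(a, b):
    return [rowa + rowb for rowa, rowb in zip(a, b)]

def join_vert(a, b):
    return a + b

def multiMatrix(x, y):
    # Assumes x and y are square with the same power-of-two size n
    n = len(x)
    if n == 1:
        # Return a 1x1 matrix, not a scalar, so the helpers above still apply
        return [[x[0][0] * y[0][0]]]
    m = n // 2  # integer division; `/` would give a float in Python 3
    # Quadrants of x
    a = [row[:m] for row in x[:m]]
    b = [row[m:] for row in x[:m]]
    c = [row[:m] for row in x[m:]]
    d = [row[m:] for row in x[m:]]
    # Quadrants of y
    e = [row[:m] for row in y[:m]]
    f = [row[m:] for row in y[:m]]
    g = [row[:m] for row in y[m:]]
    h = [row[m:] for row in y[m:]]
    # Eight recursive block products, combined with element-wise addition
    top = join_horiz(matrix_add(multiMatrix(a, e), multiMatrix(b, g)),
                     matrix_add(multiMatrix(a, f), multiMatrix(b, h)))
    bottom = join_horiz(matrix_add(multiMatrix(c, e), multiMatrix(d, g)),
                        matrix_add(multiMatrix(c, f), multiMatrix(d, h)))
    return join_vert(top, bottom)

a = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
print(multiMatrix(a, a))
# [[90, 100, 110, 120], [202, 228, 254, 280], [314, 356, 398, 440], [426, 484, 542, 600]]
```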

Edit:

I just realised that this will only work for power-of-two dimensions. Otherwise you will have to deal with non-square blocks, and at some point x will be 1 x something, so your special case won’t work. You’ll also have to check len(x[0]) (if n > 0).
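A different workaround, not the one described above, sidesteps the non-square base cases entirely: zero-pad both inputs to the next power of two, run the recursion, and crop the result. This is a minimal sketch assuming x is m x k and y is k x p (both non-empty); `pad_to`, `quadrant`, `add`, `rec_mul` and `multiply_padded` are hypothetical names chosen for this example:

```python
def pad_to(mat, size):
    # Copy mat into the top-left corner of a size x size zero matrix
    rows = [row + [0] * (size - len(row)) for row in mat]
    return rows + [[0] * size for _ in range(size - len(mat))]

def quadrant(mat, r, c, m):
    # m x m block of mat whose top-left corner is at row r, column c
    return [row[c:c + m] for row in mat[r:r + m]]

def add(p, s):
    # Element-wise sum of two same-shaped matrices
    return [[u + v for u, v in zip(rp, rs)] for rp, rs in zip(p, s)]

def rec_mul(x, y):
    # Eight-multiplication recursion; x and y are square, power-of-two sized
    n = len(x)
    if n == 1:
        return [[x[0][0] * y[0][0]]]
    m = n // 2
    corners = ((0, 0), (0, m), (m, 0), (m, m))
    x11, x12, x21, x22 = (quadrant(x, r, c, m) for r, c in corners)
    y11, y12, y21, y22 = (quadrant(y, r, c, m) for r, c in corners)
    c11 = add(rec_mul(x11, y11), rec_mul(x12, y21))
    c12 = add(rec_mul(x11, y12), rec_mul(x12, y22))
    c21 = add(rec_mul(x21, y11), rec_mul(x22, y21))
    c22 = add(rec_mul(x21, y12), rec_mul(x22, y22))
    top = [r1 + r2 for r1, r2 in zip(c11, c12)]
    bottom = [r1 + r2 for r1, r2 in zip(c21, c22)]
    return top + bottom

def multiply_padded(x, y):
    # x is m x k and y is k x p; both assumed non-empty and conformable
    m, k, p = len(x), len(y), len(y[0])
    size = 1
    while size < max(m, k, p):
        size *= 2
    full = rec_mul(pad_to(x, size), pad_to(y, size))
    # Padded rows/columns are all zero, so the top-left m x p block is the product
    return [row[:p] for row in full[:m]]

x = [[1, 2, 3], [4, 5, 6]]        # 2 x 3
y = [[7, 8], [9, 10], [11, 12]]   # 3 x 2
print(multiply_padded(x, y))      # [[58, 64], [139, 154]]
```

Padding can nearly quadruple the work in the worst case (e.g. size 5 padded to 8), so it trades speed for simpler base cases.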

Answered By: Paul Panzer
def join_horiz(a, b):
    # Place the rows of b to the right of the corresponding rows of a
    return [rowa + rowb for rowa, rowb in zip(a,b)]

def MatAdd(A,B):
    # Element-wise sum of two square matrices of the same size
    resultant = [[0 for i in range(len(A))]  for j in range(len(A))]
    for i in range(len(A)):
        for j in range(len(A)):
            resultant[i][j] = A[i][j] + B[i][j]
    return resultant

def createSubmatrices(A,starting_index,rows,columns):
    # Copy the rows x columns block of A whose top-left corner is starting_index
    resultant = [[0 for j in range(columns)] for i in range(rows)]
    for i in range(rows):
        for j in range(columns):
            resultant[i][j] = A[starting_index[0] + i][starting_index[1] + j]
    return resultant

def MatMulRecursive(A,B,n):
    # Base case: return a 1x1 matrix rather than a scalar
    if n == 1:
        return [[A[0][0]*B[0][0]]]
    else:
        # Split both matrices into four n/2 x n/2 quadrants
        A11 = createSubmatrices(A, (0,0), n//2, n//2)
        A12 = createSubmatrices(A, (0,n//2), n//2, n//2)
        A21 = createSubmatrices(A, (n//2,0), n//2, n//2)
        A22 = createSubmatrices(A, (n//2,n//2), n//2, n//2)

        B11 = createSubmatrices(B, (0,0), n//2, n//2)
        B12 = createSubmatrices(B, (0,n//2), n//2, n//2)
        B21 = createSubmatrices(B, (n//2,0), n//2, n//2)
        B22 = createSubmatrices(B, (n//2,n//2), n//2, n//2)

        # Eight recursive products, summed pairwise with element-wise addition
        C11 = MatAdd(MatMulRecursive(A11, B11, n//2), MatMulRecursive(A12, B21, n//2))
        C12 = MatAdd(MatMulRecursive(A11, B12, n//2), MatMulRecursive(A12, B22, n//2))
        C21 = MatAdd(MatMulRecursive(A21, B11, n//2), MatMulRecursive(A22, B21, n//2))
        C22 = MatAdd(MatMulRecursive(A21, B12, n//2), MatMulRecursive(A22, B22, n//2))

        # Stitch the four quadrants back into one n x n matrix
        return join_horiz(C11, C12) + join_horiz(C21, C22)


A = [[1,1,1,1], [1,5,5,1], [1,7,7,1], [3,3,3,2]]

B = [[2,2,2,2], [2,2,2,2], [2,2,2,2], [2,2,2,2]]

C = MatMulRecursive(A, B, 4)

print(C)  # [[8, 8, 8, 8], [24, 24, 24, 24], [32, 32, 32, 32], [22, 22, 22, 22]]
Answered By: Syed Hashir


If the recursive function takes only the two matrices (and works out the size from them), the code becomes cleaner.
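A sketch of that idea, assuming square power-of-two inputs: the size is read from `len(A)` instead of being passed in, and slicing replaces `createSubmatrices`. The name `mat_mul` and its inner helpers are chosen only for this example:

```python
def mat_mul(A, B):
    # n is derived from the input instead of being passed as a parameter;
    # assumes A and B are n x n with n a power of two
    n = len(A)
    if n == 1:
        return [[A[0][0] * B[0][0]]]
    h = n // 2

    def block(M, r, c):
        # h x h sub-matrix of M starting at row r, column c
        return [row[c:c + h] for row in M[r:r + h]]

    def add(X, Y):
        # Element-wise sum of two same-shaped matrices
        return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

    A11, A12, A21, A22 = block(A, 0, 0), block(A, 0, h), block(A, h, 0), block(A, h, h)
    B11, B12, B21, B22 = block(B, 0, 0), block(B, 0, h), block(B, h, 0), block(B, h, h)

    C11 = add(mat_mul(A11, B11), mat_mul(A12, B21))
    C12 = add(mat_mul(A11, B12), mat_mul(A12, B22))
    C21 = add(mat_mul(A21, B11), mat_mul(A22, B21))
    C22 = add(mat_mul(A21, B12), mat_mul(A22, B22))

    # Stitch the four quadrants back together row by row
    return [r1 + r2 for r1, r2 in zip(C11, C12)] + [r1 + r2 for r1, r2 in zip(C21, C22)]

A = [[1, 1, 1, 1], [1, 5, 5, 1], [1, 7, 7, 1], [3, 3, 3, 2]]
B = [[2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2], [2, 2, 2, 2]]
print(mat_mul(A, B))  # [[8, 8, 8, 8], [24, 24, 24, 24], [32, 32, 32, 32], [22, 22, 22, 22]]
```

Slicing copies each quadrant, just as `createSubmatrices` does, so memory use is similar; the gain is mainly readability.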

Answered By: M.KabirAhmad