Multiply tensors containing matrices, following matrix multiplication rule
Question:
Say I have a tensor M, where A, B, C, and D are all 2×2 matrices:
M = [[A, B],
     [C, D]]
How do I get M to the power of n, for example with n=2, with Python or MATLAB?
M^2 = [[A@A + B@C, A@B + B@D],
       [C@A + D@C, C@B + D@D]]
Here the power just follows the normal matrix multiplication rule; it’s just that the elements are matrices themselves. I tried matmul, matrix_power, and pagemtimes, but nothing works.
Answers:
Probably not the most efficient, but here’s a manual solution:
import numpy as np

M = np.random.randint(0, 10, (2, 2, 2, 2))

def matmatmul(a, b):
    # Block-wise product: every "element" of a and b is itself a matrix
    output = np.zeros((a.shape[0], b.shape[1]), dtype=object)
    for i in range(output.shape[0]):
        for j in range(output.shape[1]):
            row = a[i]
            col = b[:, j]
            output[i, j] = sum([r @ c for r, c in zip(row, col)])
    return output

def matmatpow(a, n):
    # Assumes n is an integer >= 1
    if n == 1:
        return a
    else:
        output = matmatmul(a, a)
        for i in range(2, n):
            output = matmatmul(output, a)
        return output

M2 = matmatpow(M, 2)
print(M2)

[[A, B], [C, D]] = M
assert np.all(M2[0, 0] == A@A + B@C)
assert np.all(M2[0, 1] == A@B + B@D)
assert np.all(M2[1, 0] == C@A + D@C)
assert np.all(M2[1, 1] == C@B + D@D)
Defining a set of (2,2) arrays, and their composite:
In [45]: A,B,C,D = [np.arange(i,i+4).reshape(2,2) for i in range(4)]
In [46]: M=np.array([[A,B],[C,D]])
Your desired M^2 array:
In [47]: np.array([[A@A + B@C, A@B + B@D],
    ...:           [C@A + D@C, C@B + D@D]])
Out[47]:
array([[[[12, 16],
         [28, 40]],

        [[16, 20],
         [40, 52]]],


       [[[28, 40],
         [44, 64]],

        [[40, 52],
         [64, 84]]]])
The same thing using einsum. Here j and l are the sum-of-products dimensions:
In [48]: np.einsum('ijkl,jmln->imkn',M,M)
Out[48]:
array([[[[12, 16],
         [28, 40]],

        [[16, 20],
         [40, 52]]],


       [[[28, 40],
         [44, 64]],

        [[40, 52],
         [64, 84]]]])
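For powers above 2, the einsum contraction above can simply be applied repeatedly. The following is a minimal sketch under that assumption; block_power is a hypothetical helper name, not a NumPy function, and it only handles integer n >= 1.

```python
import numpy as np

def block_power(M, n):
    """Raise a (p, p, q, q) tensor of q-by-q blocks to the integer power n >= 1.

    Hypothetical helper: repeats the 'ijkl,jmln->imkn' contraction n - 1 times.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")
    result = M
    for _ in range(n - 1):
        # j sums over the block index, l sums inside each block product
        result = np.einsum('ijkl,jmln->imkn', result, M)
    return result

# Same test data as in the answer above
A, B, C, D = [np.arange(i, i + 4).reshape(2, 2) for i in range(4)]
M = np.array([[A, B], [C, D]])

M2 = block_power(M, 2)
M3 = block_power(M, 3)
print(np.array_equal(M2[0, 0], A @ A + B @ C))
```

This is a plain loop, so it costs n - 1 contractions; for large n, exponentiation by squaring would need fewer.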
matmul is the equivalent of 'ijkl,ijlm->ijkm', where ij are batch dimensions, and l is the sum-of-products. Often an einsum can be reproduced with some reshape and generalized transposing. But I’ll leave that for someone else to explore.
Playing around with the einsum indices and transposing and reshaping the arrays, I can get the equivalent of:
In [56]: np.matmul(M.transpose(0,2,1,3).reshape(4,4),M.transpose(0,2,1,3).reshape(4,4))
Out[56]:
array([[12, 16, 16, 20],
       [28, 40, 40, 52],
       [28, 40, 40, 52],
       [44, 64, 64, 84]])
which, with a bit more massaging, becomes the desired (2,2,2,2) array:
In [57]: np.matmul(M.transpose(0,2,1,3).reshape(4,4),M.transpose(0,2,1,3).reshape(4,4)).reshape(2,2,2,2).transpose(0,2,1,3)
Out[57]:
array([[[[12, 16],
         [28, 40]],

        [[16, 20],
         [40, 52]]],


       [[[28, 40],
         [44, 64]],

        [[40, 52],
         [64, 84]]]])
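The transpose/reshape round trip above generalizes to any power by handing the flattened matrix to np.linalg.matrix_power. Below is a rough sketch of that idea; block_tensor_power is a hypothetical name, and the input is assumed to have shape (p, p, q, q) with an integer n.

```python
import numpy as np

def block_tensor_power(M, n):
    """Hypothetical helper: n-th power of a (p, p, q, q) tensor of blocks."""
    p, _, q, _ = M.shape
    # (i, j, k, l) -> (i, k, j, l), then merge to rows i*q + k and columns j*q + l
    flat = M.transpose(0, 2, 1, 3).reshape(p * q, p * q)
    flat_n = np.linalg.matrix_power(flat, n)
    # Undo the flattening: split the axes again and swap back to (i, j, k, l)
    return flat_n.reshape(p, q, p, q).transpose(0, 2, 1, 3)

A, B, C, D = [np.arange(i, i + 4).reshape(2, 2) for i in range(4)]
M = np.array([[A, B], [C, D]])

M2 = block_tensor_power(M, 2)
print(np.array_equal(M2, np.einsum('ijkl,jmln->imkn', M, M)))
```

Compared with looping over einsum, this leaves the repeated multiplication to NumPy and works on a single 2-D array.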
You are just computing the normal matrix product of the 4×4 block matrix created by joining the smaller matrices A through D.
In MATLAB, your expected result using some arbitrary matrices:
A = [1, 2
     3, 4];
B = [5, 6
     7, 8];
C = [ 9, 10
     11, 12];
D = [13, 14
     15, 16];
res = [A*A + B*C, A*B + B*D
       C*A + D*C, C*B + D*D]
res =
118 132 174 188
166 188 254 276
310 356 494 540
358 412 574 628
The 4×4 block matrix, and its square:
M = [A, B
     C, D];
res2 = M^2
res2 =
118 132 174 188
166 188 254 276
310 356 494 540
358 412 574 628
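For Python users, the same check can be sketched in NumPy with np.block, reusing the matrices from the MATLAB example. The variable names res and res2 mirror the MATLAB code; this is an illustration, not part of the original answer.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
C = np.array([[9, 10], [11, 12]])
D = np.array([[13, 14], [15, 16]])

# Block-wise formula from the question, assembled into one 4x4 matrix
res = np.block([[A @ A + B @ C, A @ B + B @ D],
                [C @ A + D @ C, C @ B + D @ D]])

# Ordinary square of the joined 4x4 block matrix
M = np.block([[A, B], [C, D]])
res2 = np.linalg.matrix_power(M, 2)

print(res2)
print(np.array_equal(res, res2))  # True
```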