Sympy dot product does not read dimensions correctly
Question:
I want to multiply two matrices in SymPy. However, the .dot() function does not seem to work properly: even when I use the transposed matrix so that the dimensions match, I get the same error as without it.
Code to reproduce:
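from sympy import *
x1, x2, x3 = symbols('x1 x2 x3')
x = Matrix([x1, x2, x3])  # column vector, shape (3, 1)
m = Matrix([1, 2, 3])     # column vector, shape (3, 1)
xm = x - m                # column vector, shape (3, 1)
print(xm.shape)
xmt = xm.T                # transpose: row vector, shape (1, 3)
print(xmt.shape)
c = Matrix([[1, 2, 3], [2, 1, 4], [3, 4, 5]])  # shape (3, 3)
print(c.shape)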
Output:
(3, 1)
(1, 3)
(3, 3)
Now, whether you try xm.dot(c) or xmt.dot(c), both commands result in the same error:
ShapeError: Matrix size mismatch: (3, 1) * (3, 3).
This shouldn't be possible: xm is indeed of dimensions (3, 1), but xmt is of dimensions (1, 3).
Am I doing something wrong?
Answers:
This seems like a bug in SymPy, but using dot on matrices that are not row or column vectors is already deprecated. Use the @ operator for matrix multiplication instead and it will work:
result = xmt @ c
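A minimal sketch, assuming a reasonably recent SymPy, that rebuilds the question's matrices and checks that @ returns the expected (1, 3) row vector. It also shows that left-multiplying c by the (3, 1) column vector xm fails with a "Matrix size mismatch" ShapeError, because the inner dimensions (1 and 3) differ. The names result and err are only for illustration.

```python
from sympy import Matrix, symbols

# Same matrices as in the question
x1, x2, x3 = symbols('x1 x2 x3')
xm = Matrix([x1, x2, x3]) - Matrix([1, 2, 3])  # column vector, shape (3, 1)
xmt = xm.T                                     # row vector, shape (1, 3)
c = Matrix([[1, 2, 3], [2, 1, 4], [3, 4, 5]])  # shape (3, 3)

# @ is a true matrix product: (1, 3) @ (3, 3) -> (1, 3)
result = xmt @ c
print(result.shape)  # (1, 3)
print(result)

# For SymPy matrices, * is also matrix multiplication
assert xmt * c == result

# Left-multiplying by the column vector fails: inner dimensions 1 and 3 differ
try:
    xm @ c
except ValueError as err:  # SymPy's ShapeError is a subclass of ValueError
    print("ShapeError:", err)
```

Note that with SymPy matrices, * already means matrix multiplication; @ is simply the operator Python reserves for it.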
You could open an issue on SymPy's GitHub about dot, but you shouldn't be using it for this anyway, as it is deprecated.
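For contrast, a small sketch of what dot is meant for, assuming current SymPy behavior: the inner product of two vectors with the same number of entries, which returns a scalar expression rather than a Matrix. The names s, s_row and s_mat are hypothetical.

```python
from sympy import Matrix, symbols

x1, x2, x3 = symbols('x1 x2 x3')
m = Matrix([1, 2, 3])            # column vector, shape (3, 1)
xm = Matrix([x1, x2, x3]) - m    # column vector, shape (3, 1)

# dot: inner product of two vectors of equal length -> scalar expression
s = xm.dot(m)
print(s)

# Row and column vectors may be mixed, as long as the lengths agree
s_row = xm.T.dot(m)

# Same value as the single entry of the (1, 3) @ (3, 1) matrix product
s_mat = (xm.T @ m)[0, 0]
assert (s - s_row).expand() == 0 and (s - s_mat).expand() == 0
```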
Edit: according to @hpaulj, the correct method to call in SymPy is xmt.multiply(c) if the @ operator is to be avoided.
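A sketch comparing multiply with @ on the question's matrices. The last lines go one step further and assume, without confirmation from the question, that the end goal is the quadratic form (x - m)^T c (x - m), which the variable names hint at; prod and q are hypothetical names.

```python
from sympy import Matrix, symbols, expand

x1, x2, x3 = symbols('x1 x2 x3')
xm = Matrix([x1, x2, x3]) - Matrix([1, 2, 3])  # shape (3, 1)
xmt = xm.T                                     # shape (1, 3)
c = Matrix([[1, 2, 3], [2, 1, 4], [3, 4, 5]])  # shape (3, 3)

# multiply is the method form of the matrix product, same result as xmt @ c
prod = xmt.multiply(c)
assert prod == xmt @ c

# Hypothetical follow-up: quadratic form (x - m)^T c (x - m).
# (1, 3) @ (3, 3) @ (3, 1) gives a 1x1 matrix; [0, 0] extracts the scalar.
q = (xmt @ c @ xm)[0, 0]
print(expand(q))
```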