Efficient way of computing `np.diagonal(np.dot(A, B), axis1=1, axis2=2)` using Numpy

Question:

I have a NumPy array A of shape (n, m, k) and an array B of shape (k, m). I’m wondering if there’s a more efficient way to perform the following operation:

np.diagonal(np.dot(A, B), axis1=1, axis2=2)

since np.dot performs a lot of computation I don’t need: it builds the full 3-D array of shape (n, m, m), and I only need the diagonals taken along its last two axes.

Asked By: Yandle


Answers:

You could use

np.einsum('ijk,kj->ij', A, B)

Another option is

(A * B.T).sum(axis=-1)

but in a few tests on arrays of various sizes, the einsum version was consistently faster.
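
As a quick sanity check, here is a minimal sketch comparing both expressions against the original one. The sizes and the random seed are arbitrary values picked for illustration and are not from the question.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
n, m, k = 4, 3, 5
rng = np.random.default_rng(0)
A = rng.standard_normal((n, m, k))
B = rng.standard_normal((k, m))

# Original expression: builds the full (n, m, m) product, then keeps its diagonals.
reference = np.diagonal(np.dot(A, B), axis1=1, axis2=2)

# out[i, j] = sum over k of A[i, j, k] * B[k, j], computed directly.
via_einsum = np.einsum('ijk,kj->ij', A, B)

# Same sum via broadcasting: B.T has shape (m, k) and broadcasts against (n, m, k).
via_broadcast = (A * B.T).sum(axis=-1)

shapes_match = reference.shape == via_einsum.shape == via_broadcast.shape == (n, m)
values_match = np.allclose(reference, via_einsum) and np.allclose(reference, via_broadcast)
```

Both alternatives avoid forming the (n, m, m) intermediate; the broadcasting version still allocates a temporary of shape (n, m, k) for the elementwise product before summing.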

Answered By: Warren Weckesser