How does torch.einsum perform this 4D tensor multiplication?

Question:

I came across some code that uses torch.einsum to compute a tensor multiplication. I understand how it works for lower-order tensors, but not for the 4D tensors below:

import torch

a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))

c = torch.einsum('nxhd,nyhd->nhxy', [a,b])

print(c.size())

# output: torch.Size([3, 2, 5, 4])

I need help regarding:

  1. What is the operation that has been performed here (explanation for how the matrices were multiplied/transposed etc.)?
  2. Is torch.einsum actually beneficial in this scenario?
Asked By: anurag

||

Answers:

(Skip to the tl;dr section if you just want the breakdown of steps involved in an einsum)

I'll explain step by step how einsum works for this example. Instead of torch.einsum, I'll use numpy.einsum, which uses the same subscript notation and which I'm generally more comfortable with. The same steps apply to torch as well.

Let’s rewrite the above code in NumPy –

import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape

#(3, 2, 5, 4)
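To make the semantics concrete, here is a deliberately slow sketch (not how NumPy or PyTorch actually implement it) that computes the same result with explicit loops. Every output index (n, h, x, y) is looped over, and d, which appears in the inputs but not in the output, is summed.

```python
import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))

# Reference loop version of 'nxhd,nyhd->nhxy':
# loop over the output indices (n, h, x, y) and sum over d,
# because d appears in the inputs but not in the output.
N, X, H, D = a.shape
Y = b.shape[1]
c_loop = np.zeros((N, H, X, Y))
for n in range(N):
    for h in range(H):
        for x in range(X):
            for y in range(Y):
                for d in range(D):
                    c_loop[n, h, x, y] += a[n, x, h, d] * b[n, y, h, d]

print(np.allclose(c_loop, np.einsum('nxhd,nyhd->nhxy', a, b)))  # True
```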

Step by step np.einsum

Conceptually, this einsum breaks down into 3 steps: multiply, sum, and transpose.

Let's look at our dimensions. We have a (3, 5, 2, 10) array and a (3, 4, 2, 10) array that we need to bring to (3, 2, 5, 4) based on 'nxhd,nyhd->nhxy'.
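Written out index by index, the three steps amount to the following. P and S are names introduced here for the intermediate results (they correspond to c1 and c2 in the steps below); the last line combines all three.

```latex
\begin{aligned}
P_{nxyhd} &= a_{nxhd}\, b_{nyhd} && \text{(multiply)} \\
S_{nxyh}  &= \sum_{d} P_{nxyhd}  && \text{(sum over } d\text{)} \\
c_{nhxy}  &= S_{nxyh}            && \text{(transpose)} \\
\Rightarrow\quad c_{nhxy} &= \sum_{d} a_{nxhd}\, b_{nyhd}
\end{aligned}
```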

1. Multiply

Let's not worry yet about the order of the n, x, y, h, d axes, only about whether each one is kept or removed (reduced). Writing them down as a table shows how the dimensions line up:

        ## Multiply ##
       n   x   y   h   d
      --------------------
a  ->  3   5       2   10
b  ->  3       4   2   10
c1 ->  3   5   4   2   10

For broadcasting to multiply along x and y and produce an (x, y) grid, we add a new size-1 axis in the right place in each array and then multiply.

a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)

c1 = a1*b1
c1.shape

#(3, 5, 4, 2, 10)  #<-- (n, x, y, h, d)
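The None indexing above is equivalent to np.expand_dims. The sketch below builds the same aligned arrays that way and uses np.broadcast_shapes (available in NumPy 1.20 and later) to check the broadcast shape before multiplying.

```python
import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))

# Insert the missing y axis into a and the missing x axis into b,
# so both are laid out as (n, x, y, h, d).
a1 = np.expand_dims(a, axis=2)  # (3, 5, 1, 2, 10)
b1 = np.expand_dims(b, axis=1)  # (3, 1, 4, 2, 10)

# Size-1 axes stretch to match the other operand.
print(np.broadcast_shapes(a1.shape, b1.shape))  # (3, 5, 4, 2, 10)

c1 = a1 * b1
# Same values as the None-indexing version.
print(np.array_equal(c1, a[:, :, None, :, :] * b[:, None, :, :, :]))  # True
```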

2. Sum / Reduce

Next, we reduce (sum over) the last axis, d, of size 10, because it does not appear in the output. This gives us the dimensions (n, x, y, h).

          ## Reduce ##
        n   x   y   h   d
       --------------------
c1  ->  3   5   4   2   10
c2  ->  3   5   4   2

This is straightforward: just use np.sum over axis=-1.

c2 = np.sum(c1, axis=-1)
c2.shape

#(3,5,4,2)  #<-- (n, x, y, h)
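The output part of the subscript string decides which of these steps happen. As a sketch: keeping d in the output gives only the multiply step (c1), and dropping d while keeping the (n, x, y, h) order gives the multiply-and-sum result (c2) without the final transpose.

```python
import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))

c1 = a[:, :, None, :, :] * b[:, None, :, :, :]  # multiply
c2 = np.sum(c1, axis=-1)                         # sum over d

# d kept in the output: no reduction, only the broadcast multiply.
print(np.allclose(np.einsum('nxhd,nyhd->nxyhd', a, b), c1))  # True
# d dropped from the output: summed over, output left in (n, x, y, h) order.
print(np.allclose(np.einsum('nxhd,nyhd->nxyh', a, b), c2))  # True
```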

3. Transpose

The last step rearranges the axes with a transpose. c2.transpose(0,3,1,2) (equivalently np.transpose(c2, (0,3,1,2))) moves axis 3 right after axis 0 and shifts axes 1 and 2 one position to the right. So (n,x,y,h) becomes (n,h,x,y).

c3 = c2.transpose(0,3,1,2)
c3.shape

#(3,2,5,4)  #<-- (n, h, x, y)
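The same permutation can be written in a few equivalent ways: the ndarray method, the np.transpose function, np.moveaxis, or einsum itself with a single operand. A sketch using a random stand-in for c2:

```python
import numpy as np

c2 = np.random.random((3, 5, 4, 2))  # stand-in for (n, x, y, h)

t1 = c2.transpose(0, 3, 1, 2)        # ndarray method
t2 = np.transpose(c2, (0, 3, 1, 2))  # function form takes the axes as a tuple
t3 = np.moveaxis(c2, 3, 1)           # move axis 3 (h) to position 1
t4 = np.einsum('nxyh->nhxy', c2)     # pure permutation via einsum

print(t1.shape)  # (3, 2, 5, 4)
print(all(np.array_equal(t1, t) for t in (t2, t3, t4)))  # True
```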

4. Final check

Let's do a final check that c3 is the same as the c generated by np.einsum.

np.allclose(c,c3)

#True

TL;DR

Thus, we have implemented 'nxhd,nyhd->nhxy' as:

input     -> nxhd, nyhd
multiply  -> nxyhd      #broadcasting
sum       -> nxyh       #reduce
transpose -> nhxy
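The same multiply, sum, transpose recipe works for other two-operand subscript strings. Below is a minimal sketch of a hypothetical helper, einsum_two (not a NumPy or PyTorch API), that follows the TL;DR mechanically. It assumes an explicit '->' output, no '...', and no index repeated within a single operand, so traces and diagonals are not handled.

```python
import numpy as np

def einsum_two(spec, a, b):
    """Hypothetical helper (not a NumPy API): evaluate a two-operand
    einsum string by multiply -> sum -> transpose.

    Assumes an explicit '->' output, no '...', and no index repeated
    within one operand.
    """
    inputs, out = spec.replace(" ", "").split("->")
    sa, sb = inputs.split(",")
    # Every index that appears in either input, in first-seen order.
    full = "".join(dict.fromkeys(sa + sb))

    def align(x, subs):
        # Reorder x's axes to follow `full`, then add size-1 axes for
        # the indices x does not have, so broadcasting lines them up.
        present = [i for i in full if i in subs]
        x = np.transpose(x, [subs.index(i) for i in present])
        dims = iter(x.shape)
        return x.reshape([next(dims) if i in subs else 1 for i in full])

    # 1. Multiply (broadcasting)
    prod = align(a, sa) * align(b, sb)
    # 2. Sum over every index missing from the output
    reduced = tuple(k for k, i in enumerate(full) if i not in out)
    summed = prod.sum(axis=reduced)
    # 3. Transpose the surviving axes into the output order
    kept = [i for i in full if i in out]
    return np.transpose(summed, [kept.index(i) for i in out])


a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
print(np.allclose(einsum_two('nxhd,nyhd->nhxy', a, b),
                  np.einsum('nxhd,nyhd->nhxy', a, b)))  # True
```

This materializes the full broadcast product, so it is only meant to illustrate the steps, not to replace einsum.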

Advantage

The advantage of np.einsum over the separate steps is that a single function call expresses the whole computation, and the same function can express many other operations. When an expression has more than two operands, the optimize parameter also lets NumPy choose the order ("path") in which the operands are contracted, which can reduce the cost considerably.
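With two operands there is only one contraction to do, so the path matters once a third operand is involved. A sketch with three arbitrarily sized matrices: np.einsum_path reports the order einsum would use, and the result with optimize=True matches the unoptimized one.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((10, 300))
B = rng.random((300, 5))
C = rng.random((5, 200))

# einsum_path returns the chosen contraction order and a readable report.
path, report = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='greedy')
print(path)    # a list starting with 'einsum_path', then operand pairs
print(report)  # cost breakdown of the chosen order

# optimize=True lets einsum pick an order; the result is unchanged.
r_opt = np.einsum('ij,jk,kl->il', A, B, C, optimize=True)
r_plain = np.einsum('ij,jk,kl->il', A, B, C, optimize=False)
print(np.allclose(r_opt, r_plain))  # True
```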

A non-exhaustive list of operations that einsum can compute, with the equivalent NumPy functions:

  • Trace of an array, numpy.trace.
  • Return a diagonal, numpy.diag.
  • Array axis summations, numpy.sum.
  • Transpositions and permutations, numpy.transpose.
  • Matrix multiplication and dot product, numpy.matmul, numpy.dot.
  • Vector inner and outer products, numpy.inner, numpy.outer.
  • Broadcasting, element-wise and scalar multiplication, numpy.multiply.
  • Tensor contractions, numpy.tensordot.
  • Chained array operations, in efficient calculation order, numpy.einsum_path.
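A few of these equivalences as a quick sketch on small hand-made arrays (the arrays are arbitrary; each assert compares an einsum string with the corresponding NumPy function):

```python
import numpy as np

M = np.arange(9.0).reshape(3, 3)
N = np.arange(6.0).reshape(3, 2)
u = np.array([1.0, 2.0, 3.0])
v = np.array([4.0, 5.0])
X = np.arange(24.0).reshape(2, 3, 4)
Y = np.arange(24.0).reshape(3, 4, 2)

assert np.einsum('ii', M) == np.trace(M)                            # trace
assert np.array_equal(np.einsum('ii->i', M), np.diag(M))            # diagonal
assert np.einsum('ij->', M) == np.sum(M)                            # sum of all
assert np.array_equal(np.einsum('ij->j', M), M.sum(axis=0))         # axis sum
assert np.array_equal(np.einsum('ij->ji', M), M.T)                  # transpose
assert np.allclose(np.einsum('ij,jk->ik', M, N), np.matmul(M, N))   # matmul
assert np.einsum('i,i', u, u) == np.inner(u, u)                     # inner
assert np.array_equal(np.einsum('i,j->ij', u, v), np.outer(u, v))   # outer
assert np.array_equal(np.einsum('i,i->i', u, u), np.multiply(u, u)) # element-wise
assert np.allclose(np.einsum('ijk,jkl->il', X, Y),
                   np.tensordot(X, Y, axes=([1, 2], [0, 1])))       # tensordot
print("all equivalences hold")
```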

Benchmarks

%timeit np.einsum('nxhd,nyhd->nhxy', a, b)
# 8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
# 13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

On these small arrays, np.einsum ran faster than the individual steps (about 8 µs vs. 14 µs per call).
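For this particular subscript string there is also a third route. Since d is summed and (n, h) appear in both inputs and in the output, 'nxhd,nyhd->nhxy' is a batched matrix product: for every (n, h) pair, an (x, d) matrix times a (d, y) matrix. The NumPy sketch below only checks that equivalence and makes no timing claim; in PyTorch the analogous expression would be torch.matmul(a.permute(0, 2, 1, 3), b.permute(0, 2, 3, 1)).

```python
import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))

# (n, x, h, d) -> (n, h, x, d): a stack of (x, d) matrices per (n, h)
a_batched = a.transpose(0, 2, 1, 3)
# (n, y, h, d) -> (n, h, d, y): a stack of (d, y) matrices per (n, h)
b_batched = b.transpose(0, 2, 3, 1)

# matmul treats the leading (n, h) axes as batch dimensions
# and contracts the last axis of a_batched with the second-to-last of b_batched.
c_mm = np.matmul(a_batched, b_batched)
print(c_mm.shape)  # (3, 2, 5, 4)
print(np.allclose(c_mm, np.einsum('nxhd,nyhd->nhxy', a, b)))  # True
```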

Answered By: Akshay Sehgal