How does torch.einsum perform this 4D tensor multiplication?
Question:
I have come across some code that uses torch.einsum
to compute a tensor multiplication. I can follow what it does for lower-order tensors, but not for the 4D tensors below:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
I need help with:
- What operation has been performed here (an explanation of how the matrices were multiplied/transposed, etc.)?
- Is torch.einsum actually beneficial in this scenario?
Answers:
(Skip to the tl;dr section if you just want the breakdown of steps involved in an einsum)
I’ll explain how einsum works step by step for this example, but instead of torch.einsum I’ll use numpy.einsum (documentation), which does exactly the same thing; I’m just more comfortable with it in general. The same steps apply to torch as well.
Let’s rewrite the above code in NumPy –
import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)
Step by step np.einsum
Einsum is composed of 3 steps: multiply, sum and transpose.
Let’s look at our dimensions. We have a (3, 5, 2, 10) and a (3, 4, 2, 10) that we need to bring to (3, 2, 5, 4), based on 'nxhd,nyhd->nhxy'.
1. Multiply
Let’s not worry about the order of the n, x, y, h, d axes for now, and only about whether we want to keep them or remove (reduce) them. Writing them down as a table, let’s see how we can arrange our dimensions –
## Multiply ##
n x y h d
--------------------
a -> 3 5 2 10
b -> 3 4 2 10
c1 -> 3 5 4 2 10
To get the broadcasting multiplication between the x and y axes to result in (x, y), we have to add new axes in the right places and then multiply.
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)
2. Sum / Reduce
Next, we want to reduce the last axis, d (of size 10). This will get us the dimensions (n, x, y, h).
## Reduce ##
n x y h d
--------------------
c1 -> 3 5 4 2 10
c2 -> 3 5 4 2
This is straightforward. Let’s just do np.sum over axis=-1 –
c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2) #<-- (n, x, y, h)
3. Transpose
The last step is rearranging the axes with a transpose, using np.transpose. Here transpose(0, 3, 1, 2) moves the 3rd axis to just after the 0th axis and pushes the 1st and 2nd axes back. So (n, x, y, h) becomes (n, h, x, y) –
c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4) #<-- (n, h, x, y)
4. Final check
Let’s do a final check and see whether c3 is the same as the c generated by np.einsum –
np.allclose(c,c3)
#True
TL;DR.
Thus, we have implemented 'nxhd,nyhd->nhxy' as –
input -> nxhd, nyhd
multiply -> nxyhd #broadcasting
sum -> nxyh #reduce
transpose -> nhxy
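Since the question is about torch, the same three steps translate directly back to PyTorch; here is a sketch using the tensors from the question:

```python
import torch

a = torch.rand(3, 5, 2, 10)
b = torch.rand(3, 4, 2, 10)

# multiply: insert singleton axes so x and y broadcast against each other
c1 = a[:, :, None, :, :] * b[:, None, :, :, :]  # (n, x, y, h, d)
# sum: reduce the shared d axis
c2 = c1.sum(dim=-1)                             # (n, x, y, h)
# transpose: permute to the requested output order
c3 = c2.permute(0, 3, 1, 2)                     # (n, h, x, y)

c = torch.einsum('nxhd,nyhd->nhxy', a, b)
print(c3.shape)                # torch.Size([3, 2, 5, 4])
print(torch.allclose(c, c3))   # True
```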
Advantage
The advantage of np.einsum over the multiple separate steps is that it can choose the "path" the computation takes and perform multiple operations with the same function. This is done via the optimize parameter, which optimizes the contraction order of an einsum expression.
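A short sketch of the optimize parameter and np.einsum_path, using the arrays from above (the printed path shown in the comment is illustrative):

```python
import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))

# optimize=True lets einsum pick an efficient contraction order
c = np.einsum('nxhd,nyhd->nhxy', a, b, optimize=True)

# np.einsum_path returns the chosen path plus a readable cost summary
path, info = np.einsum_path('nxhd,nyhd->nhxy', a, b, optimize='optimal')
print(path)   # e.g. ['einsum_path', (0, 1)]
print(info)   # human-readable summary of the contraction
```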
A non-exhaustive list of operations that can be computed with einsum, along with their equivalent NumPy functions:
- Trace of an array: numpy.trace
- Return a diagonal: numpy.diag
- Array axis summations: numpy.sum
- Transpositions and permutations: numpy.transpose
- Matrix multiplication and dot product: numpy.matmul, numpy.dot
- Vector inner and outer products: numpy.inner, numpy.outer
- Broadcasting, element-wise and scalar multiplication: numpy.multiply
- Tensor contractions: numpy.tensordot
- Chained array operations, in efficient calculation order: numpy.einsum_path
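A few of these equivalences, sketched as quick checks:

```python
import numpy as np

A = np.random.random((4, 4))
B = np.random.random((4, 4))
v = np.random.random(4)
w = np.random.random(3)

assert np.isclose(np.einsum('ii->', A), np.trace(A))            # trace
assert np.allclose(np.einsum('ii->i', A), np.diag(A))           # diagonal
assert np.allclose(np.einsum('ij->ji', A), A.T)                 # transpose
assert np.allclose(np.einsum('ij,jk->ik', A, B), A @ B)         # matmul
assert np.isclose(np.einsum('i,i->', v, v), np.inner(v, v))     # inner product
assert np.allclose(np.einsum('i,j->ij', v, w), np.outer(v, w))  # outer product
```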
Benchmarks
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
This shows that np.einsum performs the operation faster than the chain of individual steps, since it fuses the multiply and sum instead of materializing the full broadcast intermediate.