Understanding fancy einsum equation

Question:

I was reading about attention and came across this equation:

import einops
from fancy_einsum import einsum
import torch

x = torch.rand((200, 10, 768))
y = torch.rand((20, 768, 64))

res = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", x, y)

I am not able to understand the underlying operations that produce the result res.

I thought it might be matmul and tried this:

import torch
x_ = x.unsqueeze(dim=2).unsqueeze(dim=2)         # (200, 10, 1, 1, 768)
y_ = torch.broadcast_to(y, (1, 1, 20, 768, 64))  # (1, 1, 20, 768, 64)
res2 = x_ @ y_                                   # (200, 10, 20, 1, 64)
res2 = res2.squeeze(dim=-2)                      # (200, 10, 20, 64)
(res == res2).all()  # Prints False

But that does not seem to be right.

Any help regarding this is greatly appreciated

Asked By: Sai Prashanth


Answers:

Whenever you use einsum, it is best to think about the meaning of the dimensions. In this case we perform a multiplication between the two inputs. The signature passed to einsum shows which dimensions are preserved and which ones get "summed away". I simplified the signature to single letters here:

res = einsum("b q m, n m h -> b q n h", x, y)
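As an aside, fancy_einsum just allows descriptive axis names; the same contraction can be written with plain torch.einsum and single-letter indices. A minimal sketch with the shapes from the question:

import torch

x = torch.rand(200, 10, 768)  # (batch, query_pos, d_model) -> (b, q, m)
y = torch.rand(20, 768, 64)   # (n_heads, d_model, d_head)  -> (n, m, h)

# torch.einsum expresses the same contraction with single-letter indices
res = torch.einsum("bqm,nmh->bqnh", x, y)
print(res.shape)  # torch.Size([200, 10, 20, 64])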

We can read from the signature that both x and y have three dimensions. Furthermore, both have a dimension called m, which does not appear in the output, so we can conclude that it gets "summed away". Each entry of the output is therefore given by the following formula (for simplicity I reused the dimension names as indices): for every b, q, n, h we get

$$\text{res}[b,q,n,h] = \sum_{m} x[b,q,m] \cdot y[n,m,h]$$
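To make the formula concrete, here is a naive loop translation on tiny toy tensors (the small sizes are arbitrary, chosen only so the quadruple loop stays cheap):

import torch

b, q, m, n, h = 2, 3, 4, 5, 6  # tiny toy sizes
x = torch.rand(b, q, m)
y = torch.rand(n, m, h)

# Direct translation of res[b,q,n,h] = sum over m of x[b,q,m] * y[n,m,h]
res_loop = torch.zeros(b, q, n, h)
for bi in range(b):
    for qi in range(q):
        for ni in range(n):
            for hi in range(h):
                res_loop[bi, qi, ni, hi] = (x[bi, qi, :] * y[ni, :, hi]).sum()

print(torch.allclose(res_loop, torch.einsum("bqm,nmh->bqnh", x, y)))  # True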

Doing this efficiently with anything other than einsum is usually more cumbersome. First we need to reorder and unsqueeze the dimensions so that they are compatible for elementwise multiplication, which we can do as follows (shapes annotated in the comments):

# x[:, :, :, None, None]:  (b, q, m, 1, 1)
# y.permute(1, 0, 2):            (m, n, h)
# product:                 (b, q, m, n, h)
product = x[:, :, :, None, None] * y.permute(1, 0, 2)

Due to the broadcasting rules, the second (y) term implicitly gets the required leading dummy dimensions.
Then we can "sum away" the dimension m:

 res = product.sum(dim=2)  # (b,q,n,h)
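This also explains the False in the question: einsum and the manual broadcast-multiply-sum may add up the d_model products in a different order, so exact floating-point equality can fail even though both computations are correct. torch.allclose is the safer comparison. A small self-contained check, reusing the shapes from the question:

import torch
from fancy_einsum import einsum

x = torch.rand(200, 10, 768)
y = torch.rand(20, 768, 64)

res = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", x, y)

product = x[:, :, :, None, None] * y.permute(1, 0, 2)
res2 = product.sum(dim=2)

# The two results may sum over d_model in different orders, so the float32
# outputs can differ in the last bits:
print((res == res2).all())        # may print tensor(False)
print(torch.allclose(res, res2))  # tensor(True) within default tolerances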

So you can interpret this as a matrix multiplication if you want, or simply as a dot product over m, but with many additional "batch" dimensions.
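For illustration, the whole contraction can also be flattened into a single 2-D matrix multiplication (a sketch, using the same shapes as the question): rows are the (b, q) pairs, columns are the (n, h) pairs, and the contraction runs over m.

import torch

x = torch.rand(200, 10, 768)  # (b, q, m)
y = torch.rand(20, 768, 64)   # (n, m, h)

b, q, m = x.shape
n, _, h = y.shape

# Flatten (b, q) into rows and (n, h) into columns: one 2-D matmul
res_mm = x.reshape(b * q, m) @ y.permute(1, 0, 2).reshape(m, n * h)
res_mm = res_mm.reshape(b, q, n, h)

print(torch.allclose(res_mm, torch.einsum("bqm,nmh->bqnh", x, y)))  # True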

Answered By: flawr