How to replace this naive code with scaled_dot_product_attention() in PyTorch?

Question:

Consider a code fragment from Crossformer:

def forward(self, queries, keys, values):
    B, L, H, E = queries.shape
    _, S, _, D = values.shape
    scale = self.scale or 1./sqrt(E)

    scores = torch.einsum("blhe,bshe->bhls", queries, keys)
    A = self.dropout(torch.softmax(scale * scores, dim=-1))
    V = torch.einsum("bhls,bshd->blhd", A, values)
    
    return V.contiguous()

I’m trying to accelerate it by replacing the naive calls with Flash Attention. For that, I did the following:

def forward(self, queries, keys, values):
    # I'm not sure about the below - it's just a ChatGPT-assisted guess
    # B represents the batch size.
    # L is the sequence length for queries (or target sequence length).
    # H is the number of attention heads.
    # E is the depth (dimension) of each attention head for queries/keys.
    # S is the sequence length for keys/values (or source sequence length).
    # D is the depth (dimension) of each attention head for values.
    B, L, H, E = queries.shape
    _, S, _, D = values.shape

    y = torch.nn.functional.scaled_dot_product_attention(
        queries, keys, values, dropout_p=self.dropout_p if self.training else None)
    y = y.contiguous()
    return y

However, with the above code, I’m getting the following error:

RuntimeError: The size of tensor a (10) must match the size of tensor b (4) at non-singleton dimension 1

The debugger shows me the following tensor sizes:

  • keys: (2048, 4, 16, 32)
  • queries: (2048, 10, 16, 32)
  • values: (2048, 4, 16, 32)

What am I missing in this change?

Asked By: Serge Rogatch


Answers:

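scaled_dot_product_attention expects query, key and value shaped (N, ..., L, E): the sequence dimension must be at dimension -2, and every dimension before it (here batch and heads) is treated as a batch dimension (see the documentation). Your tensors use the (B, L, H, E) layout, so the heads sit at dimension -2 instead of the sequence.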
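To see where the reported error comes from: with the (B, L, H, E) layout the call multiplies (B, L) batches of (H, E) matrices, and the query batch shape (2048, 10) cannot broadcast against the key batch shape (2048, 4). A minimal sketch reproducing this with random tensors (batch shrunk from 2048 to 2; try_sdpa is a hypothetical helper, and the exact error text may differ between PyTorch versions):

```python
import torch
import torch.nn.functional as F


def try_sdpa(q, k, v):
    """Return the output shape, or the error message if the call fails."""
    try:
        return tuple(F.scaled_dot_product_attention(q, k, v).shape)
    except RuntimeError as err:
        return f"RuntimeError: {err}"


B, L, S, H, E = 2, 10, 4, 16, 32
queries = torch.randn(B, L, H, E)  # (B, L, H, E), the layout from the question
keys = torch.randn(B, S, H, E)
values = torch.randn(B, S, H, E)

# Wrong layout: (B, L) vs (B, S) are treated as batch dims, and 10 != 4.
print(try_sdpa(queries, keys, values))

# Right layout: (B, H, L, E) and (B, H, S, E), sequence at dim -2.
print(try_sdpa(queries.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)))
# (2, 16, 10, 32)
```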

So you need to swap dimensions 1 and 2 (sequence and heads) before the call, and swap them back on the output:

y = torch.nn.functional.scaled_dot_product_attention(
   queries.transpose(1, 2),
   keys.transpose(1, 2),
   values.transpose(1, 2),
   dropout_p=self.dropout_p if self.training else 0
).transpose(1, 2)
y = y.contiguous()
return y
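One way to check that the transposed call reproduces the original einsum computation is to run both on random inputs with dropout disabled and compare them. This sketch assumes self.scale is None, so the original falls back to 1/sqrt(E), which is also the default scale used by scaled_dot_product_attention; the function names einsum_attention and sdpa_attention are hypothetical.

```python
from math import sqrt

import torch
import torch.nn.functional as F


def einsum_attention(queries, keys, values):
    # Original Crossformer computation with dropout omitted; inputs are (B, L, H, E).
    E = queries.shape[-1]
    scale = 1.0 / sqrt(E)
    scores = torch.einsum("blhe,bshe->bhls", queries, keys)  # (B, H, L, S)
    A = torch.softmax(scale * scores, dim=-1)
    return torch.einsum("bhls,bshd->blhd", A, values).contiguous()  # (B, L, H, D)


def sdpa_attention(queries, keys, values):
    # Move heads in front of the sequence dim, run SDPA, then move them back.
    y = F.scaled_dot_product_attention(
        queries.transpose(1, 2),
        keys.transpose(1, 2),
        values.transpose(1, 2),
        dropout_p=0.0,
    ).transpose(1, 2)
    return y.contiguous()


torch.manual_seed(0)
B, L, S, H, E = 8, 10, 4, 16, 32
q = torch.randn(B, L, H, E)
k = torch.randn(B, S, H, E)
v = torch.randn(B, S, H, E)

ref = einsum_attention(q, k, v)
out = sdpa_attention(q, k, v)
print(out.shape)  # torch.Size([8, 10, 16, 32])
print(torch.allclose(ref, out, atol=1e-5))
```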

Also note that dropout_p must be a float (0.0 when dropout is not applied); passing None, as in your code, is not accepted.

Answered By: ftorre