Understanding key_dim and num_heads in tf.keras.layers.MultiHeadAttention

Question:

For example, I have an input with shape (1, 1000, 10) (so src.shape will be (1, 1000, 10)). Then:

  • This works:

import tensorflow as tf

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        # 20 heads, each projecting queries/keys to key_dim = 9
        self.attention1 = tf.keras.layers.MultiHeadAttention(num_heads=20, key_dim=9)
        self.dense = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, src):
        output = self.attention1(src, src)       # self-attention: query = value = src
        output = tf.reshape(output, [1, 10000])  # flatten (1, 1000, 10) -> (1, 10000)
        output = self.dense(output)
        return output
  • And this:

class Model(tf.keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.attention1 = tf.keras.layers.MultiHeadAttention(num_heads=123, key_dim=17)
        self.dense = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, src):
        output = self.attention1(src, src)
        output = tf.reshape(output, [1, 10000])
        output = self.dense(output)
        return output

So, this layer works with any num_heads and key_dim, but the sequence length (i.e. 1000) should be divisible by num_heads. Why? Is it a bug? For example, the same code in PyTorch doesn't work. Also, what is key_dim then? Thanks in advance.

Asked By: Alex


Answers:

There are two dimensions d_k and d_v in the original paper.

  • key_dim corresponds to d_k, the dimensionality of the key and query projections for each head. d_k can be larger or smaller than d_v.

  • d_v is the size of the value projection for each head. The convention d_v = embed_dim / num_heads is not strictly required; it is, however, typical in a Transformer, so that concatenating the values across the heads yields a vector of the same size as the original embedding (see the shape check after this list).
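As a quick check of how these sizes show up in the Keras layer, here is a minimal sketch using the (1, 1000, 10) input and the num_heads=20, key_dim=9 layer from the question (the commented shapes assume TensorFlow 2.x defaults, where value_dim falls back to key_dim and the output size falls back to the query's last dimension):

import tensorflow as tf

src = tf.random.normal((1, 1000, 10))  # (batch, seq_len, embed_dim)
mha = tf.keras.layers.MultiHeadAttention(num_heads=20, key_dim=9)

out = mha(src, src)
print(out.shape)  # (1, 1000, 10) -- the output shape follows the query, not key_dim or num_heads

# key_dim and num_heads only change the internal projection weights:
for w in mha.weights:
    print(w.name, w.shape)
# Expected kernel shapes (biases omitted):
#   query:  (10, 20, 9)  -> each of the 20 heads projects the 10-dim input to d_k = 9
#   key:    (10, 20, 9)
#   value:  (10, 20, 9)  -> d_v defaults to key_dim here
#   output: (20, 9, 10)  -> concatenated heads projected back to the 10-dim embedding

This also suggests why the num_heads=123, key_dim=17 version in the question runs: each head gets its own key_dim-sized projection rather than a slice of the embedding, so neither the sequence length nor the embedding size needs to be divisible by num_heads in this layer.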

In their paper, Vaswani et al. set d_k = d_v. This, however, is also not required. Conceptually, you can have d_k << d_v or even d_k >> d_v: in the former, each key/query is reduced in dimensionality within each head; in the latter, it is expanded. The change in dimension is handled transparently by the shape of the weight matrix that is multiplied into each query/key/value.
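In the Keras layer, d_v is exposed as the value_dim argument (when omitted it defaults to key_dim), so d_k and d_v can be chosen independently. A minimal sketch with d_k << d_v, assuming the same (1, 1000, 10) input as above:

import tensorflow as tf

src = tf.random.normal((1, 1000, 10))

# d_k = 4 per head for queries/keys, d_v = 32 per head for values (d_k << d_v)
mha = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=4, value_dim=32)

out = mha(src, src)
print(out.shape)  # still (1, 1000, 10): the output projection maps num_heads * value_dim back to the query dimension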

Answered By: Anirban Mukherjee