How to mask variable-length inputs in a transformer model when each item in a batch needs to be masked differently?

Question:

I’m making a transformer using tensorflow.keras and having issues understanding how the attention_mask works for a MultiHeadAttention layer.

My input is 3-dimensional data. For example, let’s assume my whole dataset has 10 elements, each one with length no more than 4:

# whole data
[
  # first item
  [
    [     1,      2,      3],
    [     1,      2,      3],
    [np.nan, np.nan, np.nan],
    [np.nan, np.nan, np.nan],
  ],
  # second item
  [
    [     1,      2,      3],
    [     5,      8,      2],
    [     3,      7,      8],
    [     4,      6,      2],
  ],
  ... # 8 more items
]

So, my mask looks like:

# assume this is a numpy array
mask = [
  [
    [1, 1, 1],
    [1, 1, 1],
    [0, 0, 0],
    [0, 0, 0],
  ],
  [
    [1, 1, 1],
    [1, 1, 1],
    [1, 1, 1],
    [1, 1, 1],
  ],
  ...
]

So far, the mask has shape [10, 4, 3]. Let's say I use batch_size = 5. According to the documentation, the attention_mask shape should be [B, T, S] (batch_size, query length, key length). In this example, should it be [5, 4, 4]?

Question

If the mask is computed only once, which 5 items should I pass as the mask? This seems counterintuitive to me. How should I build the mask?

According to this answer, the attention-head dimension should also be taken into account, so they also do:

mask = mask[:, tf.newaxis, tf.newaxis, :]

What I’ve tested

The only time I manage to run the transformer successfully using the attention_mask is when I do:

mask = np.ones((batch_size, data.shape[1], data.shape[2]))
mask = mask[:, tf.newaxis, tf.newaxis, :]

Obviously that mask makes no sense, because it is all ones, but it was just to test if it had the correct shape.

The model

I’m using practically the same code from the keras example transformer for time series classification

def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0.0, mask=None):
    """Build one pre-norm Transformer encoder block with the Keras functional API.

    Args:
        inputs: symbolic input tensor; its last dimension is the feature size.
        head_size: ``key_dim`` of every attention head.
        num_heads: number of attention heads.
        ff_dim: filters of the hidden 1x1 convolution in the feed-forward part.
        dropout: dropout rate used inside attention and after each sub-layer.
        mask: forwarded unchanged as ``attention_mask`` to ``MultiHeadAttention``.

    Returns:
        A tensor with the same shape as ``inputs`` (the output of two residual
        additions; the last convolution projects back to ``inputs.shape[-1]``).
    """
    n_features = inputs.shape[-1]

    # Self-attention sub-layer, normalised first, with a residual connection.
    normed = layers.LayerNormalization(epsilon=1e-6)(inputs)
    attended = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(normed, normed, attention_mask=mask)
    attn_block_out = layers.Dropout(dropout)(attended) + inputs

    # Position-wise feed-forward sub-layer (1x1 convolutions) with a residual connection.
    hidden = layers.LayerNormalization(epsilon=1e-6)(attn_block_out)
    hidden = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(hidden)
    hidden = layers.Dropout(dropout)(hidden)
    projected = layers.Conv1D(filters=n_features, kernel_size=1)(hidden)
    return projected + attn_block_out


def build_model(
    n_classes,
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0.0,
    mlp_dropout=0.0,
    input_mask=None,
) -> keras.Model:
    """Assemble a Transformer-encoder classifier with the Keras functional API.

    Args:
        n_classes: number of output classes (softmax units).
        input_shape: per-sample shape given to ``keras.Input`` (no batch axis).
        head_size, num_heads, ff_dim, dropout: forwarded to every
            ``transformer_encoder`` block.
        num_transformer_blocks: how many encoder blocks to stack.
        mlp_units: widths of the ReLU ``Dense`` layers in the classification head.
        mlp_dropout: dropout rate after each of those ``Dense`` layers.
        input_mask: passed unchanged to every block as its ``mask``, i.e. as
            ``MultiHeadAttention``'s ``attention_mask``.

    Returns:
        An uncompiled ``keras.Model`` mapping the inputs to class probabilities.
    """
    inputs = keras.Input(shape=input_shape)

    features = inputs
    for _block in range(num_transformer_blocks):
        features = transformer_encoder(
            features,
            head_size=head_size,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout=dropout,
            mask=input_mask,
        )

    # channels_first: average over the last two axes, keeping axis 1.
    # NOTE(review): GlobalAveragePooling2D expects a 4-D batch, so input_shape
    # needs three entries; confirm this matches the real data.
    head = layers.GlobalAveragePooling2D(data_format="channels_first")(features)
    for units in mlp_units:
        head = layers.Dense(units, activation="relu")(head)
        head = layers.Dropout(mlp_dropout)(head)
    probabilities = layers.Dense(n_classes, activation="softmax")(head)
    return keras.Model(inputs, probabilities)
Asked By: Jorge Morgado

||

Answers:

First, a simpler example to understand MultiHeadAttention mask.

# Crude self-attention implementation
query = tf.constant([[1.], [2.], [3.], [4.]])  # Shape([4, 1]), float32

# Unnormalized, pre-softmax scores: dot product of every token with every token.
scores = query @ tf.transpose(query)  # Shape([4, 4])

The above are the attention scores for the given query. attention_mask is used to prevent attention to certain positions in these scores, so the mask must have the same shape as the attention scores (or a shape that broadcasts to it).

Let's say we decide that each token in the above example may attend only to itself and to the next token; then we can define the mask as:

# 1 = the query token may attend to this key, 0 = masked out -- the same
# convention Keras' MultiHeadAttention uses for attention_mask.
mask = tf.constant([[1., 1., 0., 0.],
                    [0., 1., 1., 0.],
                    [0., 0., 1., 1.],
                    [0., 0., 0., 1.]])

#apply mask on the score: add a large negative number at masked positions so
#softmax gives them (numerically) zero weight. Do NOT multiply the scores by a
#mask containing -inf: 0 * -inf is NaN, and a negative score times -inf becomes
#+inf, i.e. the "masked" position would get all the weight. Using a finite
#large negative (instead of -inf) also keeps a fully masked row free of NaN.
scores = scores + (1.0 - mask) * -1e9

#softmax 
scores = tf.nn.softmax(scores)

#scores, ( 0 indicates no attention) 
[[0.26894143, 0.73105854, 0.        , 0.        ],
 [0.        , 0.11920292, 0.880797  , 0.        ],
 [0.        , 0.        , 0.04742587, 0.95257413],
 [0.        , 0.        , 0.        , 1.        ]]

#score weighted queries
value = tf.matmul(scores, query)

#value is a weighted average of the current and next token of ( [[1], [2], [3], [4]])
[[1.7310585], #weighted average of ([1], [2]) (current and next)
 [2.8807971],
 [3.9525743],
 [4.       ]]

Can there be a different mask for each item in the batch?

Yes. A typical use case is padding: samples in the same batch can be padded by different amounts, so each sample's mask is set to ignore its own padded positions.

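In your specific case, the mask has to be (batch_size, 4, 4). It can be the same for every item in the batch, but it does not have to be: each item gets its own (4, 4) slice, as in the example below, which uses a random mask per item.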

batch_size = 5
# Symbolic inputs (batch axis implicit): each sample is 4 steps x 3 features, and
# each sample also gets its own (4, 4) mask = (query positions, key positions),
# so the batched mask is (batch_size, 4, 4).
query = keras.Input(shape=(4, 3))
mask_tensor = keras.Input(shape=(4, 4))

#keras layer
mha = keras.layers.MultiHeadAttention(num_heads=1, key_dim=3)
# return_attention_scores=True makes the layer return (attention_output, attention_scores).
output = mha(query=query, value=query, attention_mask=mask_tensor, return_attention_scores=True)

#Create a model
# The mask is a second model input, so every batch can supply its own per-item masks.
model = keras.Model([query, mask_tensor], output)

#random query and mask. Note the mask needs to be (1:attention or 0:no attention) 
queries = tf.random.normal(shape=(batch_size, 4, 3))
# Random 0/1 mask, drawn independently for each item in the batch.
mask_data = tf.random.uniform(maxval=2, shape=(batch_size, 4, 4), dtype=tf.int32)

#calling the model
values, attn_weights = model.predict([queries, mask_data])

# attn_weights.shape -> (batch_size, num_heads, query_len, key_len)
(5, 1, 4, 4)

After some research and after looking at several transformer model examples, this is what solved the problem for me:

  1. Create a custom TransformerBlock layer that supports masking
  2. Add a mask parameter in the call method of the TransformerBlock and reshape it there.
  3. Add a Masking layer before the TransformerBlock

Code:

class TransformerBlock(layers.Layer):
    """Pre-norm Transformer encoder block that applies an upstream Keras padding mask.

    Because ``supports_masking`` is True, Keras hands the mask produced by an
    upstream ``layers.Masking`` layer to :meth:`call` and propagates it to the
    next layer, so several blocks can be stacked after a single ``Masking``.

    Args:
        head_size: ``key_dim`` of each attention head.
        num_heads: number of attention heads.
        ff_dim: filters of the hidden 1x1 convolution in the feed-forward part.
        ff_dim2: filters of the output 1x1 convolution; should equal the input's
            last dimension so the residual addition lines up.
        rate: dropout rate applied after attention and inside the feed-forward part.
        **kwargs: standard ``Layer`` arguments (``name``, ``dtype``, ...),
            forwarded to the base class.
    """

    def __init__(self, head_size, num_heads, ff_dim, ff_dim2, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Remember the constructor arguments so get_config() can rebuild the layer.
        self.head_size = head_size
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.ff_dim2 = ff_dim2
        self.rate = rate

        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=head_size)
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)
        self.conv1 = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")
        self.conv2 = layers.Conv1D(filters=ff_dim2, kernel_size=1)
        # Lets Keras pass the upstream mask into call() and propagate it onwards.
        self.supports_masking = True

    def call(self, inputs, training=None, mask=None):
        """Apply masked self-attention and the feed-forward part to ``inputs``.

        ``mask`` is the Keras mask from upstream (or None). Two new axes are
        inserted after the batch axis. For example, a (batch, seq, features)
        input gives a (batch, seq) mask from ``Masking``, which becomes
        (batch, 1, 1, seq). That broadcasts over heads and query positions, so
        it blocks attention *to* padded key positions. Padded query positions
        still produce outputs.
        """
        padding_mask = None
        if mask is not None:
            padding_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype="int32")

        # Self-attention sub-layer with a residual connection.
        out_norm1 = self.layernorm1(inputs, training=training)
        out_att = self.att(
            out_norm1, out_norm1, training=training, attention_mask=padding_mask
        )
        out_drop1 = self.dropout1(out_att, training=training)
        res = out_drop1 + inputs

        # Feed-forward sub-layer (1x1 convolutions) with a residual connection.
        out_norm2 = self.layernorm2(res, training=training)
        out_conv1 = self.conv1(out_norm2, training=training)
        out_drop2 = self.dropout2(out_conv1, training=training)
        out_conv2 = self.conv2(out_drop2, training=training)
        return out_conv2 + res

    def get_config(self):
        """Return base-layer config plus constructor arguments (for from_config / saving)."""
        config = super().get_config()
        config.update(
            {
                "head_size": self.head_size,
                "num_heads": self.num_heads,
                "ff_dim": self.ff_dim,
                "ff_dim2": self.ff_dim2,
                "rate": self.rate,
            }
        )
        return config

def build_model(
    n_classes,
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0.0,
    mlp_dropout=0.0,
    mask=None,
) -> keras.Model:
    """Build a Transformer classifier whose padded steps are masked out of attention.

    Args:
        n_classes: number of output classes (softmax units).
        input_shape: per-sample shape given to ``keras.Input`` (no batch axis).
        head_size, num_heads, ff_dim: forwarded to every ``TransformerBlock``.
        num_transformer_blocks: how many ``TransformerBlock`` layers to stack.
        mlp_units: widths of the ReLU ``Dense`` layers in the classification head.
        dropout: dropout rate inside each ``TransformerBlock``.
        mlp_dropout: dropout rate after each head ``Dense`` layer.
        mask: the padding *value* (not a mask array). When given, a
            ``layers.Masking(mask_value=mask)`` layer marks the padded steps and
            every ``TransformerBlock`` turns that into an attention mask. It must
            not be NaN: ``Masking`` compares inputs with ``!=`` and
            ``NaN != NaN`` is always true, so nothing would be masked (and the
            NaNs would propagate through the network). Replace NaN padding in
            the data with a numeric value (e.g. 0.0) and pass that value.

    Returns:
        An uncompiled ``keras.Model`` mapping the inputs to class probabilities.

    Raises:
        ValueError: if ``mask`` is NaN.
    """
    import math  # local import: only needed for the NaN guard below

    if mask is not None:
        try:
            mask_is_nan = math.isnan(mask)
        except TypeError:
            # Non-scalar padding values are left for Masking itself to handle.
            mask_is_nan = False
        if mask_is_nan:
            raise ValueError(
                "mask (the padding value) must not be NaN: layers.Masking compares "
                "inputs with !=, and NaN != NaN, so no step would be masked. "
                "Replace NaN padding in the data with a numeric value (e.g. 0.0) "
                "and pass that value instead."
            )

    inputs = keras.Input(shape=input_shape)
    _x = inputs
    if mask is not None:
        _x = layers.Masking(mask_value=mask)(_x)
    for _ in range(num_transformer_blocks):
        _x = TransformerBlock(
            head_size,
            num_heads,
            ff_dim,
            inputs.shape[-1],  # ff_dim2: project back to the input feature size
            dropout,
        )(_x)

    # channels_first: average over the last two axes, keeping axis 1.
    # NOTE(review): GlobalAveragePooling2D expects a 4-D batch, so input_shape
    # needs three entries; confirm this matches the real data.
    _x = layers.GlobalAveragePooling2D(data_format="channels_first")(_x)
    for dim in mlp_units:
        _x = layers.Dense(dim, activation="relu")(_x)
        _x = layers.Dropout(mlp_dropout)(_x)
    outputs = layers.Dense(n_classes, activation="softmax")(_x)
    return keras.Model(inputs, outputs)
Answered By: Jorge Morgado