Tensorflow – Decoder for Machine Translation

Question:

I am going through TensorFlow's tutorial on Neural Machine Translation with an attention mechanism.

It has the following code for the Decoder:

class Decoder(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
    super(Decoder, self).__init__()
    self.batch_sz = batch_sz
    self.dec_units = dec_units
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(self.dec_units,
                                   return_sequences=True,
                                   return_state=True,
                                   recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_size)

    # used for attention
    self.attention = BahdanauAttention(self.dec_units)

  def call(self, x, hidden, enc_output):
    # enc_output shape == (batch_size, max_length, hidden_size)
    context_vector, attention_weights = self.attention(hidden, enc_output)

    # x shape after passing through embedding == (batch_size, 1, embedding_dim)
    x = self.embedding(x)

    # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
    x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

    # passing the concatenated vector to the GRU
    output, state = self.gru(x)

    # output shape == (batch_size * 1, hidden_size)
    output = tf.reshape(output, (-1, output.shape[2]))

    # output shape == (batch_size, vocab)
    x = self.fc(output)

    return x, state, attention_weights

What I don't understand here is that the decoder's GRU is not connected to the encoder by initializing it with the encoder's last hidden state.

output, state = self.gru(x)  

# Why is it not initialized with the hidden state of the encoder?

As per my understanding, the encoder and decoder are connected only when the decoder is initialized with the "thought vector", i.e. the encoder's last hidden state.

Why is that missing in TensorFlow's official tutorial? Is it a bug? Or am I missing something here?

Could someone help me understand ?

Asked By: AnonymousMe


Answers:

This is summarized very well by this detailed NMT guide, which compares the classic seq2seq NMT architecture against the attention-based encoder-decoder architecture.

Vanilla seq2seq: The decoder also needs to have access to the source information, and one simple way to achieve that is to initialize it with the last hidden state of the encoder, encoder_state.
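For contrast, here is a minimal, self-contained sketch (not from the tutorial; all sizes and tensors are made up for illustration) of what that vanilla wiring looks like: the encoder's final hidden state is handed to the decoder GRU through initial_state, so that single vector is the decoder's only view of the source sentence.

import tensorflow as tf

# Hypothetical sizes, for illustration only.
vocab_size, embedding_dim, units, batch_sz, max_len = 5000, 256, 1024, 64, 20

# Encoder: a GRU that returns both its per-step outputs and its final state.
enc_embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
encoder_gru = tf.keras.layers.GRU(units, return_sequences=True, return_state=True)

src = tf.random.uniform((batch_sz, max_len), maxval=vocab_size, dtype=tf.int32)
enc_output, encoder_state = encoder_gru(enc_embedding(src))

# Vanilla seq2seq: the decoder GRU is explicitly seeded with the encoder's
# final hidden state, so this "thought vector" carries all source information.
dec_embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
decoder_gru = tf.keras.layers.GRU(units, return_sequences=True, return_state=True)

tgt_in = tf.random.uniform((batch_sz, 1), maxval=vocab_size, dtype=tf.int32)
dec_output, dec_state = decoder_gru(dec_embedding(tgt_in),
                                    initial_state=encoder_state)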

Attention-based encoder-decoder: Remember that in the vanilla seq2seq model, we pass the last source state from the encoder to the decoder when starting the decoding process. This works well for short and medium-length sentences; however, for long sentences, the single fixed-size hidden state becomes an information bottleneck. Instead of discarding all of the hidden states computed in the source RNN, the attention mechanism provides an approach that allows the decoder to peek at them (treating them as a dynamic memory of the source information). By doing so, the attention mechanism improves the translation of longer sentences.
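Concretely, the BahdanauAttention layer used by the Decoder above looks roughly like this (reproduced from memory of the tutorial, so treat it as a sketch rather than the exact code): the decoder's previous hidden state is the query, and all of the encoder's per-step outputs are the values it scores against, which is exactly the "dynamic memory" described above.

class BahdanauAttention(tf.keras.layers.Layer):
  def __init__(self, units):
    super(BahdanauAttention, self).__init__()
    self.W1 = tf.keras.layers.Dense(units)
    self.W2 = tf.keras.layers.Dense(units)
    self.V = tf.keras.layers.Dense(1)

  def call(self, query, values):
    # query: previous decoder hidden state, shape (batch_size, hidden_size)
    # values: all encoder outputs, shape (batch_size, max_length, hidden_size)
    query_with_time_axis = tf.expand_dims(query, 1)

    # score shape == (batch_size, max_length, 1)
    score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))

    # attention_weights shape == (batch_size, max_length, 1)
    attention_weights = tf.nn.softmax(score, axis=1)

    # context_vector shape after sum == (batch_size, hidden_size)
    context_vector = tf.reduce_sum(attention_weights * values, axis=1)

    return context_vector, attention_weights

So instead of the GRU being seeded with one fixed-size summary, every decoding step recomputes a context vector from all encoder hidden states.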

In both cases, you can use teacher forcing to better train the model.
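Below is a hedged sketch of such a teacher-forcing training step, loosely following the tutorial's train_step; the names encoder, decoder, optimizer, loss_function, targ_lang and BATCH_SIZE are assumed from the tutorial's surrounding code and may differ in yours. Note that the encoder's final hidden state is still handed to the decoder here as dec_hidden, where it serves as the attention query rather than as the GRU's initial_state.

@tf.function
def train_step(inp, targ, enc_hidden):
  loss = 0
  with tf.GradientTape() as tape:
    enc_output, enc_hidden = encoder(inp, enc_hidden)

    # The encoder's final state becomes the first `hidden` that the
    # decoder's attention layer uses as its query.
    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)

    for t in range(1, targ.shape[1]):
      predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
      loss += loss_function(targ[:, t], predictions)

      # Teacher forcing: feed the ground-truth token as the next input.
      dec_input = tf.expand_dims(targ[:, t], 1)

  batch_loss = loss / int(targ.shape[1])
  variables = encoder.trainable_variables + decoder.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return batch_loss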

TL;DR: the attention mechanism is what lets the decoder "peek" at the encoder's hidden states, instead of you explicitly passing the encoder's final state to the decoder.

Answered By: Akshay Sehgal