How does a multi-layer GRU/LSTM work in PyTorch?

Question:

I’m trying to understand exactly how the calculations are performed in the PyTorch GRU class. I’m having some trouble reading the PyTorch GRU documentation and the TorchScript LSTM documentation together with its code implementation.

The GRU documentation states:

In a multilayer GRU, the input x_t^(l) of the l-th layer (l >= 2) is the hidden state h_t^(l−1) of the previous layer multiplied by dropout δ_t^(l−1), where each δ_t^(l−1) is a Bernoulli random variable which is 0 with probability dropout.

So essentially, given a sequence, each time step should be passed through all the layers before moving on to the next time step, like this implementation.
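Roughly something like this (a sketch of my own using nn.GRUCell, not the code from the link):

import torch
import torch.nn as nn

# Rough sketch of the "time-major" reading: advance every layer by one
# time step before moving on to the next time step.
def time_major_gru(cells, x, h0):
    # cells: list of nn.GRUCell, one per layer
    # x: (seq_len, batch, input_size), h0: (num_layers, batch, hidden_size)
    h = [h0[l] for l in range(len(cells))]
    outputs = []
    for t in range(x.size(0)):               # outer loop over time steps
        inp = x[t]
        for l, cell in enumerate(cells):     # inner loop over layers
            h[l] = cell(inp, h[l])
            inp = h[l]                       # layer l's output feeds layer l+1
        outputs.append(inp)                  # keep only the top layer's output
    return torch.stack(outputs), torch.stack(h)

cells = [nn.GRUCell(10, 20), nn.GRUCell(20, 20)]
out, hn = time_major_gru(cells, torch.randn(5, 3, 10), torch.zeros(2, 3, 20))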

Meanwhile, the LSTM code implementation is:

from typing import List, Tuple

import torch
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter


def script_lstm(input_size, hidden_size, num_layers, bias=True,
                batch_first=False, dropout=False, bidirectional=False):
    '''Returns a ScriptModule that mimics a PyTorch native LSTM.'''

    # The following are not implemented.
    assert bias
    assert not batch_first

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    elif dropout:
        stack_type = StackedLSTMWithDropout
        layer_type = LSTMLayer
        dirs = 1
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(num_layers, layer_type,
                      first_layer_args=[LSTMCell, input_size, hidden_size],
                      other_layer_args=[LSTMCell, hidden_size * dirs,
                                        hidden_size])


class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(LSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = input.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    layers = [layer(*first_layer_args)] + [layer(*other_layer_args)
                                           for _ in range(num_layers - 1)]
    return nn.ModuleList(layers)


class StackedLSTM(jit.ScriptModule):
    __constants__ = ['layers']  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super(StackedLSTM, self).__init__()
        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
                                        other_layer_args)

    @jit.script_method
    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states

So in this case each layer runs its own for loop over the sequence and passes a new sequence tensor to the next layer.

So my question is: which is the correct way to implement a multi-layer GRU?

Asked By: Marcelaus


Answers:

In the PyTorch GRU documentation, you will find that nn.GRU takes a num_layers argument, which lets you specify the number of stacked GRU layers.

Hopefully this answers your question as to how the GRU layers are applied in practice:

>>> rnn = nn.GRU(input_size = 10, hidden_size = 20, num_layers = 2)
>>> input = torch.randn(5, 3, 10)   # (seq_len, batch, input_size)
>>> h0 = torch.randn(2, 3, 20)      # (num_layers, batch, hidden_size)
>>> output, hn = rnn(input, h0)     # output: (5, 3, 20), hn: (2, 3, 20)
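And since the documentation quote you mention is about the dropout applied between layers: that is just an argument as well. A minimal sketch (the value 0.2 is an arbitrary choice of mine):

>>> # dropout is applied to the outputs of every GRU layer except the last one
>>> rnn = nn.GRU(input_size = 10, hidden_size = 20, num_layers = 2, dropout = 0.2)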
Answered By: Leo

I think you are misunderstanding the definition. The approach that you see in the LSTM code, where each layer passes an entire sequence on to the next, is the standard approach for stacked RNNs – at least for sequence-to-sequence models. It’s equivalent to RNN(RNN(input)).

It’s also what the PyTorch GRU definition is saying, albeit in a somewhat roundabout way. The definition says that for the N-th GRU layer, the input is the hidden state (read: output) of the (N−1)-th GRU layer. Now, in theory, we could run the inputs one time step at a time through all the layers and collect the outputs. Or we can run the entire sequence through each layer and only keep the last layer’s output sequence. This second approach should be faster, because it allows the calculations to be vectorized more efficiently.
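As a rough illustration of that layer-major view (a sketch of my own, not code from the question’s link):

import torch
import torch.nn as nn

# Layer-major stacking, i.e. roughly GRU2(GRU1(x)): each layer consumes
# the full output sequence produced by the layer below it.
gru1 = nn.GRU(input_size=10, hidden_size=20)   # first layer
gru2 = nn.GRU(input_size=20, hidden_size=20)   # second layer takes layer 1's outputs

x = torch.randn(5, 3, 10)      # (seq_len, batch, input_size)
out1, h1 = gru1(x)             # whole sequence through layer 1: (5, 3, 20)
out2, h2 = gru2(out1)          # then the whole output sequence through layer 2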

Further, if you look at the link you sent with the two different GRU models, you’ll see that the results are equivalent whether you run the inputs through each layer one time step at a time using GRUCells, or use full GRU layers.
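If you want to convince yourself numerically, here is a minimal sketch of my own (not the linked code) that copies the weights of a two-layer nn.GRU into two nn.GRUCells and runs the time-step-by-time-step loop; with no dropout, the outputs match up to floating-point tolerance:

import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 10, 20, 2

gru = nn.GRU(input_size, hidden_size, num_layers)   # built-in, layer-major
cells = [nn.GRUCell(input_size if l == 0 else hidden_size, hidden_size)
         for l in range(num_layers)]

# Copy the built-in GRU's per-layer parameters into the cells so both
# computations use identical weights.
with torch.no_grad():
    for l, cell in enumerate(cells):
        cell.weight_ih.copy_(getattr(gru, f"weight_ih_l{l}"))
        cell.weight_hh.copy_(getattr(gru, f"weight_hh_l{l}"))
        cell.bias_ih.copy_(getattr(gru, f"bias_ih_l{l}"))
        cell.bias_hh.copy_(getattr(gru, f"bias_hh_l{l}"))

x = torch.randn(seq_len, batch, input_size)
h0 = torch.zeros(num_layers, batch, hidden_size)

out_builtin, _ = gru(x, h0)

# "Time-major" loop: advance all layers by one time step at a time.
h = [h0[l] for l in range(num_layers)]
outputs = []
for t in range(seq_len):
    inp = x[t]
    for l, cell in enumerate(cells):
        h[l] = cell(inp, h[l])
        inp = h[l]
    outputs.append(inp)
out_manual = torch.stack(outputs)

print(torch.allclose(out_builtin, out_manual, atol=1e-6))   # expect: True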

Answered By: Sean