self() as function within class, what does it do?

Question:

Sorry for the poor title but I’m unsure how better to describe the question.

I recently watched Andrej Karpathy’s "build GPT" video, which is awesome. While trying to reconstruct the code myself, I noticed that he uses self() as a function, and I was curious why and what exactly it does.

The code is here and I’m curious in particular about the generate function:

import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

So it seems to me that he is calling the forward function defined within the class by using self(). Is that correct? And if so, why would he not use forward(idx) instead? Thank you for your help!

Asked By: IloveR


Answers:

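Yes, that is correct, and it is a PyTorch convention. An nn.Module instance is callable: model(x) invokes nn.Module.__call__, which runs any registered hooks and then calls model.forward(x). Inside the class, self is that same instance, so self(idx) is essentially the same as self.forward(idx). Note that a bare forward(idx) would not work at all, because forward is a method and has to be reached through self. Calling the module itself rather than self.forward(idx) is the recommended style, since calling forward directly silently skips the hooks.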
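
To see why calling an instance ends up in forward, here is a minimal pure-Python sketch. TinyModule and Doubler are hypothetical names, and TinyModule is a heavily simplified stand-in for nn.Module written for illustration only; it is not PyTorch's actual implementation, which does considerably more.

```python
class TinyModule:
    """Hypothetical, simplified stand-in for torch.nn.Module (illustration only)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Store a callable to run after every call made through __call__.
        self._forward_hooks.append(hook)

    def __call__(self, *args, **kwargs):
        # Python runs __call__ whenever an instance is used like a function,
        # so both model(x) and self(x) land here.
        output = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses define forward")


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x

    def generate(self, x, steps):
        for _ in range(steps):
            # self(x) is self.__call__(x), which dispatches to self.forward(x)
            x = self(x)
        return x


calls = []  # outputs recorded by the hook
model = Doubler()
model.register_forward_hook(lambda module, inputs, output: calls.append(output))

via_call = model(3)                     # goes through __call__, so the hook fires
via_forward = model.forward(3)          # bypasses __call__, so the hook does not fire
generated = model.generate(1, steps=3)  # each self(x) inside also fires the hook
```

Both via_call and via_forward hold the same value, but only the calls made through the instance are recorded in calls. That difference is the practical reason to prefer self(idx) over self.forward(idx) inside a module.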
