What does it mean when your model can't overfit a small batch of data?

Question:

I am trying to train an RNN model to classify sentences into 4 classes, but it doesn't seem to work. I tried to overfit 4 examples (blue line in the training-loss plot), which worked, but even as few as 8 examples (red line) will not overfit, let alone the whole dataset.

I tried different learning rates and different values of hidden_size and embedding_size, but it doesn't seem to help. What am I missing? I know that if a model is not able to overfit a small batch, its capacity should be increased, but in this case increasing the capacity has no effect.
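
For reference, the single-batch overfitting check I mean looks roughly like the sketch below (overfit_one_batch is a placeholder name, not my exact code): one small, fixed batch is trained on repeatedly, and the loss should approach zero if the model and training loop are sound.

import torch.nn as nn
import torch.optim as optim

def overfit_one_batch(model, texts, labels, steps=500, lr=1e-3):
    # Train repeatedly on one fixed small batch; a healthy model and
    # training loop should drive this loss close to zero.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(texts), labels)
        loss.backward()
        optimizer.step()
    return loss.item()
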
The architecture is as follows:

import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, embedding_size=256, hidden_size=128, num_classes=4):
        super().__init__()
        # padding_idx=0 matches the 0-padding used in the batches
        self.embedding = nn.Embedding(len(vocab), embedding_size, padding_idx=0)
        self.rnn = nn.RNN(embedding_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: [batch_size, sequence_length]
        x = self.embedding(x)   # [batch_size, sequence_length, embedding_size]
        _, h_n = self.rnn(x)    # h_n: [1, batch_size, hidden_size]
        h_n = h_n.squeeze(0)    # [batch_size, hidden_size]
        out = self.fc(h_n)      # [batch_size, num_classes]
        return out

The input data consists of tokenized sentences, padded with 0 up to the longest sentence in the batch, so one sample looks like: [2784, 9544, 1321, 120, 0, 0]. The data comes from the AG_NEWS dataset in torchtext.datasets.
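
A collate function along the following lines shows how such padded batches can be built (a sketch assuming a pad_sequence-based collate, not the exact code used; the tokenizer/vocab lookup is omitted, and padding_value=0 matches the embedding's padding index):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    # batch: list of (label, token_id_list) pairs
    labels = torch.tensor([label for label, _ in batch], dtype=torch.long)
    texts = [torch.tensor(tokens, dtype=torch.long) for _, tokens in batch]
    # pad every sequence in the batch to the length of the longest one
    texts = pad_sequence(texts, batch_first=True, padding_value=0)
    return labels, texts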

The training code:

import torch.optim as optim

model = RNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
model.train()

for epoch in range(NUM_EPOCHS):
    epoch_losses = []
    correct_predictions = []
    total_samples = 0
    for batch_idx, (labels, texts) in enumerate(train_loader):
        # keep the batch on the same device as the model
        labels, texts = labels.to(device), texts.to(device)
        scores = model(texts)
        loss = criterion(scores, labels)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_losses.append(loss.item())
        correct = (scores.max(1).indices == labels).sum().item()
        correct_predictions.append(correct)
        total_samples += labels.size(0)

    epoch_avg_loss = sum(epoch_losses) / len(epoch_losses)
    # average over all samples seen this epoch, not just the last batch
    epoch_avg_accuracy = sum(correct_predictions) / total_samples
Asked By: SlimakSlimak


Answers:

The issue was due to the vanishing gradient problem.

Answered By: SlimakSlimak
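
In a plain nn.RNN, the gradient signal shrinks as it is backpropagated through long input sequences, which matches the symptom of a few short examples fitting while a larger batch does not. A common remedy, sketched here as one possible fix rather than the answerer's own code, is to replace nn.RNN with a gated unit such as nn.GRU, which is designed to preserve gradient flow:

import torch.nn as nn

class GRUClassifier(nn.Module):
    # Same interface as the question's RNN class, with nn.GRU in place of nn.RNN.
    def __init__(self, vocab_size, embedding_size=256, hidden_size=128, num_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
        self.rnn = nn.GRU(embedding_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.embedding(x)           # [batch, seq_len, embedding_size]
        _, h_n = self.rnn(x)            # h_n: [1, batch, hidden_size]
        return self.fc(h_n.squeeze(0))  # [batch, num_classes]

nn.GRU returns (output, h_n) just like nn.RNN, so the training loop from the question works unchanged with this class.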