PyTorch Lightning: Order of model method calls

Question:

I’m trying to reimplement a training pipeline on top of PyTorch Lightning.

In the documentation, they explain that the training and validation loops are executed this way:
[Images from the Lightning documentation illustrating the training and validation loops]

My understanding was that the order was:

  • train_step()
  • train_epoch_end()
  • val_step()
  • val_epoch_end()

I wrote some dummy code to check this:


import pytorch_lightning as pl
from torchmetrics import MeanMetric, SumMetric
from torch.utils.data import Dataset,DataLoader
import torch
import warnings
warnings.filterwarnings("ignore")

class DummyDataset(Dataset):
    def __init__(self):
        pass
    def __getitem__(self,idx):
        return torch.zeros([3,12,12]),torch.ones([3,12,12]) # Dummy image Like...
    def __len__(self):
        return 50

class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3,3,1,1) # Useless convolution
        self.mean = MeanMetric()
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(),lr=1e-3)
    def training_step(self, batch,batch_idx):
        x,y=batch
        y_hat = self(x)
        loss = torch.sum((y-y_hat)**2)
        self.mean.update(2)
        return loss

    def training_epoch_end(self, outputs):
        mean_train = self.mean.compute()
        print(f"\nmean_train is : {mean_train}\n")
        self.mean.reset()

    def validation_step(self, batch,batch_idx):
        x,y=batch
        y_hat = self(x)
        loss = torch.sum((y-y_hat)**2)
        self.mean.update(4)
        return loss

    def validation_epoch_end(self, outputs):
        mean_val = self.mean.compute()
        print(f"\nmean_val is : {mean_val}\n")
        self.mean.reset()

    def forward(self,x):
        return self.conv(x)

if __name__=='__main__':
    dataset = DummyDataset()
    train_loader=DataLoader(dataset,batch_size=4,num_workers=0)
    val_loader=DataLoader(dataset,batch_size=4,num_workers=0)
    model = DummyModel()
    # We create trainer
    trainer = pl.Trainer(val_check_interval=None)
    # We fit model
    trainer.fit(model,train_dataloaders=train_loader,val_dataloaders=val_loader)

What I see in the output is:

  • mean_val is : 3
  • mean_train is : nan

This is consistent with what I see in the debugger, and the order is:

  • train_step()
  • val_step()
  • val_epoch_end()
  • train_epoch_end()

Is that really the case?
Did I do something wrong?
How does it work?
Thanks!

Asked By: FrsECM


Answers:

The sequence you observe is correct. Here is a sketch of how it is implemented:

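for epoch in range(max_epochs):
    train_outputs = []
    for batch_idx, batch in enumerate(train_dataloader):
        train_outputs.append(model.training_step(batch, batch_idx))

        if should_validate():
            val_outputs = []
            for val_idx, val_batch in enumerate(val_dataloader):
                val_outputs.append(model.validation_step(val_batch, val_idx))
            model.validation_epoch_end(val_outputs)

    model.training_epoch_end(train_outputs)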
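
The sketch above can be checked without Lightning. Below is a hypothetical, framework-free simulation: the names `simulate_fit`, `compact` and `val_every` are made up for illustration, and the rule deciding when to validate is a simplified assumption, not Lightning's actual scheduling code. It only records which hook would run, in what order.

```python
# Hypothetical, framework-free simulation of the loop sketch above.
# It only records hook names; it is NOT Lightning's real implementation.


def simulate_fit(num_train_batches, num_val_batches, max_epochs=1, val_every=None):
    """Return the list of hook names in the order they would be called.

    val_every=None mimics the default: validate once, on the last training batch.
    An int roughly mimics an integer val_check_interval: validate after every
    `val_every` training batches (simplified assumption).
    """
    calls = []
    for _epoch in range(max_epochs):
        for batch_idx in range(num_train_batches):
            calls.append("training_step")
            if val_every is None:
                should_validate = batch_idx == num_train_batches - 1
            else:
                should_validate = (batch_idx + 1) % val_every == 0
            if should_validate:
                # The whole validation loop runs inside the training loop.
                calls.extend(["validation_step"] * num_val_batches)
                calls.append("validation_epoch_end")
        # Only after the last training batch (and its validation) is done.
        calls.append("training_epoch_end")
    return calls


def compact(calls):
    """Collapse consecutive repeats so the order is easy to read."""
    out = []
    for name in calls:
        if not out or out[-1] != name:
            out.append(name)
    return out


print(compact(simulate_fit(num_train_batches=3, num_val_batches=2)))
```

With the default (`val_every=None`) this prints `['training_step', 'validation_step', 'validation_epoch_end', 'training_epoch_end']`, the same order observed in the question. Passing, for example, `val_every=2` interleaves several validation passes within one training epoch, which is roughly what an integer `val_check_interval` does.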

As you can see, the validation loop runs inside the training loop and can be triggered after any training batch. This is configured with Trainer(val_check_interval=x): an integer x runs validation every x training batches, while a float runs it after that fraction of the training epoch.

By default, however, it validates once per epoch, i.e. every len(train_dataloader) batches, so the should_validate condition is only true on the very last batch of the epoch. This is why you see in your prints:

val_epoch_end()
train_epoch_end()

(the two hooks run back to back).
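
To see why the question's output is `mean_val is : 3` followed by `mean_train is : nan`, here is a plain-Python replay of the shared-metric arithmetic. `RunningMean` is a hypothetical stand-in for `torchmetrics.MeanMetric` (assumed: `compute()` on an empty metric gives `nan`, which matches the question's output). It also leaves out the sanity-check validation pass Lightning runs before training, assumed here to end with a reset so it does not affect the numbers.

```python
import math


class RunningMean:
    """Hypothetical stand-in for torchmetrics.MeanMetric (update/compute/reset only).

    Assumption: compute() on an empty metric returns nan, matching the question.
    """

    def __init__(self):
        self.reset()

    def update(self, value):
        self.total += value
        self.count += 1

    def compute(self):
        return self.total / self.count if self.count else math.nan

    def reset(self):
        self.total = 0.0
        self.count = 0


dataset_len, batch_size = 50, 4
num_batches = math.ceil(dataset_len / batch_size)  # 13 batches, the last one has 2 samples

mean = RunningMean()  # ONE metric shared by training and validation, as in the question
printed = {}
for batch_idx in range(num_batches):
    mean.update(2)  # training_step
    if batch_idx == num_batches - 1:  # default: validate on the last training batch
        for _ in range(num_batches):
            mean.update(4)  # validation_step (same dataset, so 13 batches)
        printed["mean_val"] = mean.compute()  # validation_epoch_end
        mean.reset()
printed["mean_train"] = mean.compute()  # training_epoch_end: metric was just reset
mean.reset()

print(printed)
```

The training steps (13 updates of 2) and validation steps (13 updates of 4) accumulate into the same metric before validation_epoch_end computes it, so the mean is (13*2 + 13*4) / 26 = 3. The reset in validation_epoch_end then leaves training_epoch_end with an empty metric, hence nan. Giving training and validation their own metric instances would avoid this interference.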

I hope this explanation helps.

Answered By: awaelchli

The calling sequence of a Trainer in Lightning, including some of the callback functions:

[Image: the calling sequence of a Trainer, including some callback functions]

Answered By: Frank Xu