Difference between forward and training_step in PyTorch Lightning?

Question:

I have a transfer-learning ResNet set up in PyTorch Lightning. The structure is borrowed from this wandb tutorial https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning–VmlldzoyODk1NzY

and from looking at the documentation https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html

I am confused about the difference between the forward() and the training_step() methods.

In the PL documentation, the model is not called in the training step, only in forward. But forward is also not called in the training step. I have been running the model on data and the outputs look sensible (I have an image callback, and I can see that the model is learning and getting a good accuracy result at the end). But I am worried that, given the forward method is not being called, the model is somehow not being used properly?

Model code is:

# Imports assumed from the usual libraries (torch, pytorch_lightning, torchmetrics)
import torch
from torch import nn, optim
import pytorch_lightning as pl
from torchmetrics import Accuracy, Precision, Recall, JaccardIndex


class TransferLearning(pl.LightningModule):
    "Works for Resnet at the moment"
    def __init__(self, model, learning_rate, optimiser = 'Adam', weights = [ 1/2288  , 1/1500], av_type = 'macro' ):
        super().__init__()
        self.class_weights = torch.FloatTensor(weights)
        self.optimiser = optimiser
        self.thresh  =  0.5
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        
        #add metrics for tracking 
        self.accuracy = Accuracy()
        self.loss= nn.CrossEntropyLoss()
        self.recall = Recall(num_classes=2, threshold=self.thresh, average = av_type)
        self.prec = Precision( num_classes=2, average = av_type )
        self.jacq_ind = JaccardIndex(num_classes=2)
        

        # init model
        backbone = model
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained backbone to classify the 2 damage classes
        num_target_classes = 2
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        recall = self.recall(preds, y)
        precision = self.prec(preds, y)
        jac = self.jacq_ind(preds, y)

        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        self.log('train_recall', recall, on_step=True, on_epoch=True, logger=True)
        self.log('train_precision', precision, on_step=True, on_epoch=True, logger=True)
        self.log('train_jacc', jac, on_step=True, on_epoch=True, logger=True)
        return loss
  
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)

        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        recall = self.recall(preds, y)
        precision = self.prec(preds, y)
        jac = self.jacq_ind(preds, y)


        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        self.log('val_recall', recall, prog_bar=True)
        self.log('val_precision', precision, prog_bar=True)
        self.log('val_jacc', jac, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        
        # test metrics
        preds = torch.argmax(logits, dim=1)
        acc = self.accuracy(preds, y)
        recall = self.recall(preds, y)
        precision = self.prec(preds, y)
        jac = self.jacq_ind(preds, y)


        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        self.log('test_recall', recall, prog_bar=True)
        self.log('test_precision', precision, prog_bar=True)
        self.log('test_jacc', jac, prog_bar=True)


        return loss
    
    def configure_optimizers(self):
        print('Optimise with {}'.format(self.optimiser))

        # Support Adam, SGD, RMSprop and Adagrad as optimisers.
        # Note: the "Adam" option actually uses AdamW (decoupled weight decay).
        if self.optimiser == "Adam":
            optimiser = optim.AdamW(self.parameters(), lr=self.learning_rate)
        elif self.optimiser == "SGD":
            optimiser = optim.SGD(self.parameters(), lr=self.learning_rate)
        elif self.optimiser == "Adagrad":
            optimiser = optim.Adagrad(self.parameters(), lr=self.learning_rate)
        elif self.optimiser == "RMSProp":
            optimiser = optim.RMSprop(self.parameters(), lr=self.learning_rate)
        else:
            raise ValueError(f"Unknown optimiser: {self.optimiser}")

        return optimiser
Asked By: Grace


Answers:

self(x) in training_step invokes the __call__ method of your module (inherited from nn.Module), which in turn calls your forward() method.

You can see exactly what happens inside self(x) in the PyTorch source code: https://github.com/pytorch/pytorch/blob/b6672b10e153b63748874ca9008fd3160f38c3dd/torch/nn/modules/module.py#L1124
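
As a small, self-contained illustration (simplified; the real Module.__call__ also runs hooks and other bookkeeping), calling a module and calling its forward() directly give the same output:

import torch
from torch import nn

layer = nn.Linear(4, 2)
x = torch.randn(3, 4)

out_call = layer(x)             # goes through nn.Module.__call__
out_forward = layer.forward(x)  # direct call to forward()

print(torch.equal(out_call, out_forward))  # True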

Answered By: joe32140

The main difference is in how the outputs of the model are being used.

In Lightning, the idea is that you organize the code in such a way that training logic is separated from inference logic.

forward: Encapsulates the way the model would be used regardless of whether you are training or performing inference.

training_step: Contains all the computations necessary to produce a loss value for training the model. Usually there are additional layers like decoders, discriminators, loss functions, etc. that are only useful for training and not needed when the trained model is used at inference time. Here we usually call forward() as well.

The way OP has organized their code is the best practice:

def forward(self, x):
    self.feature_extractor.eval()
    with torch.no_grad():
        representations = self.feature_extractor(x).flatten(1)
    x = self.classifier(representations)
    return x

def training_step(self, batch, batch_idx):
    x, y = batch

    ## self(x) is the same as calling self.forward(x)
    logits = self(x)  
    
    # Loss computation is not part of forward because it's only
    # needed for training
    loss = self.loss(logits, y)

Reference: Introduction to PyTorch Lightning (see section FORWARD vs TRAINING_STEP)

Answered By: awaelchli

I am confused about the difference between the forward() and the training_step() methods.

Quoting from the docs:

"In Lightning we suggest separating training from inference. The training_step defines the full training loop. We encourage users to use the forward to define inference actions."

So forward() defines your prediction/inference actions. It doesn't even need to be part of your training_step, in which you define your whole training loop. Nonetheless, you can choose to call it in your training_step if you want it that way. An example where forward() isn't part of the training_step:

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the training loop.
        # in this case it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss
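
Lightning's prediction path is also built on forward: if you don't override predict_step, its default behaviour is essentially the method sketched below, so trainer.predict(...) routes every batch through forward().

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # roughly what Lightning does by default: send the batch through forward()
        return self(batch)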

the model is not called in the training step, only in forward. But
forward is also not called in the training step

The fact that forward() is not called in your training_step is because self(x) does it for you. You can alternatively call forward() explicitly instead of using self(x).
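
For example, these two lines inside training_step produce the same logits; the first form is generally preferred because it goes through nn.Module.__call__ and therefore also runs any hooks registered on the module:

logits = self(x)          # dispatches to forward() via nn.Module.__call__
logits = self.forward(x)  # direct call; skips registered hooks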

I am worried that, given the forward method is not being called, the model is somehow not being used properly?

As long as you see the metrics logged with self.log moving in the right direction, you will know that your model is being called correctly and that it is learning.
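
To tie it together, here is a hedged sketch of how the Trainer drives everything; you never call training_step yourself, and the backbone, learning rate, and dataloaders below are placeholders:

import pytorch_lightning as pl
import torchvision.models as models

# illustrative wiring only: train_loader / val_loader are assumed to exist
backbone = models.resnet50(weights="IMAGENET1K_V1")
module = TransferLearning(model=backbone, learning_rate=1e-3)

trainer = pl.Trainer(max_epochs=5)
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)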

Answered By: Mike B