Extracting hidden features from Autoencoders using Pytorch

Question:

Following the tutorials in this post, I am trying to train an autoencoder and extract the features from its hidden layer.

So here are my questions:

  1. In the autoencoder class, there is a "forward" function. However, I cannot see this function being called anywhere in the code. So how does the model get trained?

  2. I ask because I think that, to extract the features, I should add another function ("forward_hidden") to the autoencoder class:

     def forward(self, features):
         activation = self.encoder_hidden_layer(features)
         activation = torch.relu(activation)
         code = self.encoder_output_layer(activation)
         code = torch.relu(code)
         activation = self.decoder_hidden_layer(code)
         activation = torch.relu(activation)
         activation = self.decoder_output_layer(activation)
         reconstructed = torch.relu(activation)
         return reconstructed
    
     def forward_hidden(self, features):
         activation = self.encoder_hidden_layer(features)
         activation = torch.relu(activation)
         code = self.encoder_output_layer(activation)
         code = torch.relu(code)
         return code
    

Then, after training (that is, after this line in the main code):

    print("AE, epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs_AE, loss))

I can add the following code to retrieve the features from the hidden layer:

    hidden_features = model_AE.forward_hidden(my_input)
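A self-contained sketch of this approach, assuming MNIST-sized inputs (784 features) and a 128-unit code. The `AE` class name, the layer sizes and the random `my_input` batch are assumptions, since the full tutorial code is not shown here. Extraction is wrapped in `torch.no_grad()` because no gradients are needed at that point.

```python
import torch
import torch.nn as nn


class AE(nn.Module):
    """Hypothetical autoencoder using the layer names from the question."""

    def __init__(self, input_shape=784, code_size=128):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(input_shape, code_size)
        self.encoder_output_layer = nn.Linear(code_size, code_size)
        self.decoder_hidden_layer = nn.Linear(code_size, code_size)
        self.decoder_output_layer = nn.Linear(code_size, input_shape)

    def forward_hidden(self, features):
        # Encoder only: returns the code (bottleneck activations).
        activation = torch.relu(self.encoder_hidden_layer(features))
        return torch.relu(self.encoder_output_layer(activation))

    def forward(self, features):
        # Full pass: encode with forward_hidden, then decode to the input size.
        code = self.forward_hidden(features)
        activation = torch.relu(self.decoder_hidden_layer(code))
        return torch.relu(self.decoder_output_layer(activation))


model_AE = AE()
my_input = torch.rand(16, 784)  # hypothetical batch of 16 flattened 28x28 images

model_AE.eval()  # no dropout/batchnorm here, but a good habit for inference
with torch.no_grad():  # feature extraction does not need gradients
    hidden_features = model_AE.forward_hidden(my_input)

print(hidden_features.shape)  # torch.Size([16, 128])
```

Having `forward` reuse `forward_hidden` is a design choice in this sketch: it keeps the encoder path used for training and the one used for extraction identical, so they cannot drift apart.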

Is this approach correct? I am also still wondering how the "forward" function is used during training, because I cannot see it being called anywhere in the code.

Asked By: Kadaj13


Answers:

forward is the core of your model: it defines what the model actually computes.

It is called implicitly when you write model(input) during training: nn.Module implements __call__, which runs any registered hooks and then calls forward.
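A minimal sketch of a typical training loop, assuming a toy `TinyAE` model (hypothetical) with a counter added only to make the implicit calls visible. The loop never names `forward`; every `model(data)` goes through `nn.Module.__call__`, which invokes it.

```python
import torch
import torch.nn as nn


class TinyAE(nn.Module):
    """Hypothetical minimal autoencoder that counts how often forward runs."""

    def __init__(self, n_features=8, code_size=3):
        super().__init__()
        self.encoder = nn.Linear(n_features, code_size)
        self.decoder = nn.Linear(code_size, n_features)
        self.forward_calls = 0  # instrumentation for this demo only

    def forward(self, features):
        self.forward_calls += 1
        code = torch.relu(self.encoder(features))
        return self.decoder(code)


torch.manual_seed(0)
model = TinyAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()
data = torch.rand(32, 8)  # hypothetical training batch

for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(data)  # nn.Module.__call__ -> hooks, then self.forward(data)
    loss = criterion(outputs, data)
    loss.backward()  # gradients flow back through what forward computed
    optimizer.step()

print(model.forward_calls)  # 5: one forward per model(data) call
```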

If you are asking how to extract intermediate features after running the model, you can register a forward hook as described here; it will "catch" the values for you.
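A sketch of the forward-hook approach, assuming the same hypothetical `AE` layer names and sizes as in the question (784 inputs, 128-unit code). It uses `register_forward_hook` on `encoder_output_layer`. Because that forward applies `torch.relu` functionally *after* the layer, the hook sees the pre-activation output, so the ReLU is applied to the captured tensor.

```python
import torch
import torch.nn as nn


class AE(nn.Module):
    """Hypothetical autoencoder using the question's layer names."""

    def __init__(self, input_shape=784, code_size=128):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(input_shape, code_size)
        self.encoder_output_layer = nn.Linear(code_size, code_size)
        self.decoder_hidden_layer = nn.Linear(code_size, code_size)
        self.decoder_output_layer = nn.Linear(code_size, input_shape)

    def forward(self, features):
        activation = torch.relu(self.encoder_hidden_layer(features))
        code = torch.relu(self.encoder_output_layer(activation))
        activation = torch.relu(self.decoder_hidden_layer(code))
        return torch.relu(self.decoder_output_layer(activation))


captured = {}


def save_output(name):
    # A forward hook is called as hook(module, inputs, output) after the module runs.
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook


model_AE = AE()
handle = model_AE.encoder_output_layer.register_forward_hook(save_output("code"))

with torch.no_grad():
    reconstructed = model_AE(torch.rand(16, 784))  # the hook fires during this call

handle.remove()  # stop capturing once the features are collected

# The hook captured the Linear layer's raw output; forward applies torch.relu
# afterwards, so apply it here to get the same code that forward uses.
hidden_features = torch.relu(captured["code"])
print(hidden_features.shape)  # torch.Size([16, 128])
```

Unlike a separate `forward_hidden` method, a hook needs no change to the model class, at the cost of having to know which layer (and which activation) produces the features you want.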

Answered By: Lior Cohen

When you define a class that subclasses nn.Module in PyTorch, its forward function is called implicitly whenever you call the model (model(input)), so you do not need to call it separately.

Answered By: Mohit Gaikwad