How to unfreeze layers from a densenet? (PyTorch)

Question:

I’d like to perform fine-tuning of an entire block from DenseNet-161. At the moment, I know I can use the following to freeze all layers apart from the classifier:

import torch
from torchvision import models

# load a pretrained DenseNet-161 and freeze all of its parameters
model = models.densenet161(pretrained=True)
for param in model.parameters():
    param.requires_grad = False

# replace the classifier with a new, trainable two-class head
num_ftrs = model.classifier.in_features
model.classifier = torch.nn.Linear(num_ftrs, 2)

However, I’d like to unfreeze the last few layers/blocks of the DenseNet for fine-tuning. What would be the best and most elegant way of achieving this?

Asked By: TSRAI


Answers:

First of all, you can also unfreeze the classifier by setting requires_grad of its parameters to True.

for param in model.classifier.parameters():
    param.requires_grad = True

This way you keep the original pretrained parameters of that layer, instead of the new random initialization you get when you create a new nn.Linear.

That also works for any other submodule of the DenseNet. You can see which other submodules there are by printing the model.
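For instance, you can list the top-level submodules of model.features by name (a quick inspection sketch):

for name, module in model.features.named_children():
    print(name, type(module).__name__)
# for DenseNet-161, the last two entries are "denseblock4" and "norm5"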
To unfreeze the last block and the last BatchNorm, you can do

# this is a torch.nn.Sequential containing the 
# "denseblock4" and "norm5" submodules
submodules = model.features[-2:]  
for param in submodules.parameters():
    param.requires_grad = True
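To verify the effect, you can check which parameters are now trainable (a small sketch; assumes the classifier was also kept trainable as above):

trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print(trainable)  # should contain the denseblock4, norm5 and classifier parameters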

If you want to reset the parameters to a new random initialization, you can use some initializer from torch.nn.init on each parameter.


As requested in the comments: How to re-initialize the last two layers while keeping them frozen?

The last two layers contain convolutional layers and batch norm layers. While you probably want to reinitialize the convolutional layers randomly, this may not be what you want for the batch norm layers.

with torch.no_grad():  # needed to modify the parameters in place
    submodules = model.features[-2:] 
    for submodule in submodules.modules():
        if isinstance(submodule, torch.nn.Conv2d):
            # randomly re-initialize the weights
            torch.nn.init.kaiming_normal_(submodule.weight)
            if submodule.bias is not None:
                # reset the bias to zero
                torch.nn.init.zeros_(submodule.bias)
        elif isinstance(submodule, torch.nn.BatchNorm2d):
            # reset the affine parameters to their default values
            torch.nn.init.ones_(submodule.weight)
            torch.nn.init.zeros_(submodule.bias)
            # also reset the running statistics (running_mean and running_var are buffers, not parameters)
            torch.nn.init.zeros_(submodule.running_mean)
            torch.nn.init.ones_(submodule.running_var)

This code neither freezes nor unfreezes any parameters; they keep whatever requires_grad state they already had. You can freeze or unfreeze them before or afterwards using the usual procedure.
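For example, to freeze the two re-initialized submodules again afterwards, you can use the same requires_grad approach as above (a minimal sketch):

for param in model.features[-2:].parameters():
    param.requires_grad = False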

Answered By: cherrywoods