How to change the activation layer in a PyTorch pretrained module?

Question:

How can I change the activation layer of a PyTorch pretrained network?
Here is my code:

print("All modules")
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
        child=nn.SELU()
        print(child)
print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

Here is my output:

All modules
ReLU(inplace=True)
Before changing activation
ReLU(inplace=True)
SELU()
after changing activation
ReLU(inplace=True)
Asked By: Hamdard


Answers:

._modules solves the problem for me. Assigning to the loop variable (child = nn.SELU()) only rebinds a local name and never touches the network itself; you have to assign the new module back into the parent, e.g. via its _modules dict.

for name, child in net.named_children():
    if isinstance(child, (nn.ReLU, nn.SELU)):
        net._modules[name] = nn.SELU()
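
Re-running the check loop from the question should then print the new activation (assuming, as here, that the ReLU is a direct child of net):

for child in net.children():
    if isinstance(child, (nn.ReLU, nn.SELU)):
        print(child)
# SELU()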
Answered By: Hamdard

I'm assuming you create the activation layer with the module interface (nn.ReLU) rather than the functional interface (F.relu). If so, setattr works for me.

import torch
import torch.nn as nn

# This function recursively replaces every ReLU module with a SELU module.
def replace_relu_to_selu(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.SELU())
        else:
            replace_relu_to_selu(child)

########## A toy example ##########
net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True)
          )

########## Test ##########
print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# Before changing activation
# ReLU(inplace=True)
# ReLU(inplace=True)


print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# after changing activation
# SELU()
# SELU()
Answered By: zihaozhihao

I will provide a more general solution that works for any layer (and avoids other issues, such as modifying a dictionary while you loop through it, or nn.Modules nested recursively inside each other).

def replace_bn(module, name):
    '''
    Recursively put desired batch norm in nn.module module.

    set module = net to start code.
    '''
    # go through all attributes of module nn.module (e.g. network or layer) and put batch norms if present
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)

    # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, name)

replace_bn(model, 'model')

The crux is that you need to keep changing the layers recursively (mainly because you will sometimes encounter attributes that are themselves modules). I think better code than the above would add another if statement (after the batch-norm check) that detects whether it needs to recurse, and recurses if so; a sketch of that variant is below. The code above works too, but it first changes the batch norms on the outer level (the first loop) and then, with a second loop, makes sure no object that should be recursed into is missed (and then recurses).
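
A minimal sketch of that single-pass idea (my own reading of the suggestion, using named_children() instead of dir(); not the original poster's code):

def replace_bn_single_pass(module):
    # For each immediate child, either swap in a fresh BatchNorm2d or recurse into it.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.BatchNorm2d):
            new_bn = torch.nn.BatchNorm2d(child.num_features, child.eps, child.momentum,
                                          child.affine, track_running_stats=False)
            setattr(module, name, new_bn)
        else:
            replace_bn_single_pass(child)

replace_bn_single_pass(model)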

Original post: https://discuss.pytorch.org/t/how-to-modify-a-pretrained-model/60509/10

credits: https://discuss.pytorch.org/t/replacing-convs-modules-with-custom-convs-then-notimplementederror/17736/3?u=brando_miranda

Answered By: Charlie Parker

Here is a general function for replacing any layer:

def replace_layers(model, old, new):
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            ## compound module, go inside it
            replace_layers(module, old, new)
            
        if isinstance(module, old):
            ## simple module
            setattr(model, n, new)

replace_layers(model, nn.ReLU, nn.ReLU6())
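
One caveat: this passes a single new instance, so every replacement site ends up sharing the same module object. That is usually fine for stateless activations, but for layers with parameters you may want a fresh copy per site. A possible variation (my own sketch, not part of the original answer) is to pass a factory:

def replace_layers_with_factory(model, old, new_factory):
    # new_factory is a callable returning a fresh module, e.g. lambda: nn.ReLU6()
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            ## compound module, go inside it
            replace_layers_with_factory(module, old, new_factory)
        if isinstance(module, old):
            ## simple module: create a new instance for this site
            setattr(model, name, new_factory())

replace_layers_with_factory(model, nn.ReLU, lambda: nn.ReLU6(inplace=True))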

I struggled with this for a few days, so I did some digging and wrote a Kaggle notebook explaining how the different types of layers / modules are accessed in PyTorch.

Answered By: Ankur Singh

This works fine for me with the plain PyTorch API:

def replace_layer(module: nn.Module, old: type, new: nn.Module, full_name=""):
    for name, m in module.named_children():
        child_name = f"{full_name}.{name}"

        if isinstance(m, old):
            setattr(module, name, new)
            print(f"replaced {child_name}: {old}->{new}")
        elif len(list(m.children())) > 0:
            replace_layer(m, old, new, child_name)

model.apply(lambda m: replace_layer(m, nn.ReLU, nn.Hardswish(True)))
This will replace the layers and print a "trace":

replaced ._model.norm_layer.0.1.2: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.4.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
Answered By: RedEyed