PyTorch get all layers of model

Question:

What’s the easiest way to take a PyTorch model and get a list of all its layers without any nn.Sequential groupings? For example, is there a better way to do this?

import pretrainedmodels

def unwrap_model(model, out=None):
    """Collect the children of *model*, descending only into ``nn.Sequential``.

    ``nn.Sequential`` containers are recursed into and are not recorded
    themselves; every other child module is recorded as-is, without looking
    inside it.

    Args:
        model: the ``torch.nn.Module`` to unwrap.
        out: list to append the collected modules to. When omitted, the
            module-level list ``l`` is used, which is what the driver
            script that calls ``unwrap_model(model)`` relies on.

    Returns:
        The list the modules were appended to.
    """
    import torch.nn as nn

    if out is None:
        # Keep the original calling convention: append to the global ``l``.
        out = l
    for child in model.children():
        if isinstance(child, nn.Sequential):
            unwrap_model(child, out)
        else:
            out.append(child)
    return out

# Instantiate the 'xception' architecture from the pretrainedmodels package
# with pretrained='imagenet' weights and a 1000-class head.
model = pretrainedmodels.__dict__['xception'](num_classes=1000, pretrained='imagenet')
# unwrap_model appends the collected layers into this module-level list.
l = []
unwrap_model(model)

print(l)
    
Asked By: Austin

||

Answers:

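You can iterate over all modules of a model (including those inside each Sequential) with the modules() method. Note that modules() also yields the model itself first, plus every container module, so the filter below only removes nn.Sequential containers — other containers (e.g. nn.ModuleList or a custom block) are kept. Here’s a simple example: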

>>> model = nn.Sequential(nn.Linear(2, 2),
...                       nn.ReLU(),
...                       nn.Sequential(nn.Linear(2, 1),
...                                     nn.Sigmoid()))

>>> l = [module for module in model.modules() if not isinstance(module, nn.Sequential)]

>>> l

[Linear(in_features=2, out_features=2, bias=True),
 ReLU(),
 Linear(in_features=2, out_features=1, bias=True),
 Sigmoid()]
Answered By: Andreas K.

I do it like this:

def flatten(el):
    """Return *el* followed by every module nested inside it.

    Modules are listed depth-first in pre-order: each module comes before its
    own children, and siblings keep their registration order. Containers are
    included as well as leaves.
    """
    collected = [el]
    for sub in el.children():
        collected.extend(flatten(sub))
    return collected

# Custom_block_1 / Custom_block_2 are not defined here; substitute your own
# nn.Module instances.
cnn = nn.Sequential(Custom_block_1, Custom_block_2)
# flatten returns cnn itself followed by every nested module, depth-first.
layers = flatten(cnn)
Answered By: Etienne D

I needed this for a deeper model in which not all blocks were nn.Sequential.

def get_children(model: torch.nn.Module):
    """Return the leaf modules (modules with no children) of *model*.

    Leaves are collected depth-first in registration order. If *model*
    itself has no children, the module itself is returned (not wrapped in a
    list).
    """
    children = list(model.children())
    if not children:
        # A module with no children is itself a leaf.
        return model
    flatt_children = []
    for child in children:
        # Recurse once per child: a leaf comes back as a bare module, a
        # container as a list. Checking the type explicitly (instead of
        # letting ``extend`` raise TypeError) keeps iterable leaves such as
        # nn.ParameterList from having their contents spliced in.
        sub = get_children(child)
        if isinstance(sub, list):
            flatt_children.extend(sub)
        else:
            flatt_children.append(sub)
    return flatt_children
Answered By: Kees

In case you want the layers in a named dict, this is the simplest way:

# Map every dotted module name to its module ('' is the model itself).
named_layers = {layer_name: layer for layer_name, layer in model.named_modules()}

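This returns something like the following (note that named_modules() also yields the model itself under the empty-string key '', and nested submodules get dotted names such as 'layer1.0.conv1'):

{
    '': <the model itself>,
    'conv1': <some conv layer>,
    'fc1': <some fc layer>,
    # ... and other layers
}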

Example:

import torchvision.models as models

model = models.inception_v3(pretrained=True)
# Name -> module mapping for every module in the network ('' is the model itself).
named_layers = {layer_name: layer for layer_name, layer in model.named_modules()}
Answered By: Mayukh Deb

If you want a nested dictionary with names as keys and modules as values, e.g.:

{'conv1': Conv2d(...),
 'bn1': BatchNorm2d(...),
 'block1':{
    'group1':{
        'conv1': Conv2d(...),
        'bn1': BatchNorm2d(...),
        'conv2': Conv2d(...),
        'bn2': BatchNorm2d(...),
    },
    'group2':{ ...
    }, ...
}

You can combine the answers of Kees and Mayukh Deb to get:

def nested_children(m: torch.nn.Module):
    """Return the module tree of *m* as nested dicts keyed by child name.

    A module with children becomes a dict mapping each child's name (as given
    by ``named_children()``) to that child's own nested structure. A module
    with no children is returned as-is, so calling this on a leaf module
    returns the module itself rather than a dict.
    """
    children = dict(m.named_children())
    if not children:
        # No children: m is a leaf and becomes the value itself.
        return m
    # Recurse into each child exactly once.
    return {name: nested_children(child) for name, child in children.items()}
Answered By: Jetze Schuurmans

Here’s my method: you can pass in essentially any model and it will return a flat list of all its torch.nn modules.

def flatten_model(modules):
    """Flatten *modules* into a flat list of ``torch.nn`` modules.

    Accepted inputs:

    * an ``nn.Module``: returns its leaf modules (modules without children)
      in depth-first order; a module with no children gives ``[module]``;
    * an iterable of ``(name, module)`` pairs, e.g. ``model.named_children()``:
      each module is flattened as above and the results are concatenated;
    * any other iterable, e.g. ``list(model.modules())``: its items are
      returned unchanged, in order, in a new list;
    * anything that is not iterable: returned wrapped as ``[modules]``.
    """
    import torch.nn as nn

    # Dispatch on type explicitly rather than probing with tuple unpacking
    # and bare ``except:`` clauses, which hid genuine errors.
    if isinstance(modules, nn.Module):
        children = list(modules.children())
        if not children:
            return [modules]
        return [leaf for child in children for leaf in flatten_model(child)]

    try:
        items = iter(modules)
    except TypeError:
        return [modules]

    flat = []
    for item in items:
        is_named_module = (
            isinstance(item, tuple)
            and len(item) == 2
            and isinstance(item[1], nn.Module)
        )
        if is_named_module:
            flat.extend(flatten_model(item[1]))
        else:
            flat.append(item)
    return flat

Expanding on Ivan's answer (https://stackoverflow.com/a/69544742/429476):

# Collect every Conv2d / AvgPool2d / BatchNorm2d layer and print its position.
target_layers =[]
# model.modules() yields the root, every container and every leaf, in order.
module_list =[module for module in model.modules()]
# NOTE(review): for an already-flat list of modules, flatten_model returns the
# items in the same order, so `count` below is a position in model.modules()
# order (index 0 is the model itself) — confirm against the output shown.
flatted_list= flatten_model(module_list)

for count, value in enumerate(flatted_list):

    if isinstance(value, (nn.Conv2d,nn.AvgPool2d,nn.BatchNorm2d)):
        print(count, value)
        target_layers.append(value)

Result for ResNet50

1 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
2 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
7 Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
8 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
9 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
10 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
11 Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
12 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
15 Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
16 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
18 Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
19 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
20 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
21 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
22 Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
23 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
26 Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
27 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
28 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
29 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
30 Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
31 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
35 Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
36 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
37 Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
38 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
39 Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
40 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
43 Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
44 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
46 Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
47 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
48 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
49 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
50 Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
51 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
54 Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
55 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
56 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
57 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
58 Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
59 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
62 Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
63 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
64 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
65 BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
66 Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
67 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
71 Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
72 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
73 Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
74 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
75 Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
76 BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
79 Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
80 BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
82 Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
83 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
84 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
85 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
86 Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
87 BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
90 Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
91 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
92 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
93 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
94 Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
95 BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
98 Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
99 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
100 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
101 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
102 Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
103 BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
106 Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
107 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
108 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
109 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
110 Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
111 BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
114 Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
115 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
116 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
117 BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
118 Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
119 BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
123 Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
124 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
125 Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
126 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
127 Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
128 BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
131 Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
132 BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
134 Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
135 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
136 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
137 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
138 Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
139 BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
142 Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
143 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
144 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
145 BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
146 Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
147 BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Answered By: Alex Punnen
# Here is my approach: print the name and module of every leaf module
# (a module with no submodules), tab-separated.
for name, m in model.named_modules():
    # O(1) leaf test: stop at the first child instead of listing all
    # descendants of every module.
    if next(m.children(), None) is None:
        print(name, "\t", m)
Answered By: Yales Peter

Here is how I would recursively get all layers:

def get_layers(model: torch.nn.Module):
    """Return the leaf modules of *model*, depth-first.

    A module with no children is its own only leaf, so ``[model]`` is
    returned in that case.
    """
    kids = list(model.children())
    if not kids:
        return [model]
    layers = []
    for kid in kids:
        layers.extend(get_layers(kid))
    return layers
Answered By: user2648582

The simplest way to just get the layers would be

# Print the dotted name and the module itself for every module; the first
# entry yielded by named_modules() is the model itself, with the name ''.
for module_name, module in model.named_modules():
    print(f"module_name : {module_name} , value : {module}")

for example for resnet 18

import torch

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).to(device=device, non_blocking=True)

# Print the dotted name and the module itself for every module in the network
# (the first entry is the model itself, under the empty name '').
for module_name, module in model.named_modules():
    print(f"module_name : {module_name} , value : {module}")

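would yield results like the following (abridged, and showing only the module names even though the code above also prints each module; the very first entry, the model itself, has the empty name ''):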

conv1
bn1
layer1
layer1.0
layer1.0.relu
layer1.0.conv2
layer1.0.bn2
layer1.1
layer1.1.conv2
layer1.1.bn2
Answered By: Kalyan Sekhar

This gets you all the leaf layers, i.e. the modules that have no children of their own:

def flatten(model):
    """Return the leaf modules of *model* in depth-first order.

    Leaves are modules without children; if *model* has none it is returned
    alone in a list.
    """
    kids = list(model.children())
    if not kids:
        return [model]
    return [leaf for kid in kids for leaf in flatten(kid)]
Answered By: Godzilla
Categories: questions Tags: ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with a check mark at the top-right corner.