How to calculate dimensions of first linear layer of a CNN

Question:

Currently, I am working with a CNN that has a fully connected layer attached to it, and my input is a 3-channel image of size 32×32. I am wondering if there is a consistent formula I can use to calculate the input dimensions of the first linear layer from the output of the last conv/max-pooling layer. I want to be able to calculate the dimensions of the first linear layer given only information about the last Conv2d and MaxPool2d layers. In other words, I would like to compute that value without having to use information from all the preceding layers (so I don't have to manually calculate weight dimensions of a very deep network).

I also want to understand how the acceptable dimensions are calculated: what is the reasoning behind those calculations?

For some reason these calculations work and PyTorch accepted these dimensions:

val = int((32*32)/4)
self.fc1 = nn.Linear(val, 200)

and this also worked:

self.fc1 = nn.Linear(64*4*4, 200)

Why do those values work, and are there limitations to those methods of calculation? I feel like this would break if I were to change the stride or kernel size, for example.

Here is the general model architecture I was working with:

import torch
import torch.nn as nn
import torch.nn.functional as F

# define the CNN architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # convolutional layer
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        # max pooling layer
        self.pool = nn.MaxPool2d(2, 2)  


        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        self.dropout = nn.Dropout(0.25)

        # H*W/4
        val = int((32*32)/4)
        #self.fc1 = nn.Linear(64*4*4, 200)
        ################################################
        self.fc1 = nn.Linear(val, 200)  # dimensions of the layer I wish to calculate
        ###############################################
        self.fc2 = nn.Linear(200, 100)
        self.fc3 = nn.Linear(100, 10)


    def forward(self, x):
        # add sequence of convolutional and max pooling layers
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        #print(x.shape)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x

# create a complete CNN
model = Net()
print(model)

Can anyone tell me how to calculate the dimensions of the first linear layer and explain the reasoning?

Asked By: Vanstorm


Answers:

Given the input spatial dimension w, a 2D convolution layer will output a tensor with the following size along that dimension:

int((w + 2*p - d*(k - 1) - 1)/s + 1)

where k is the kernel size, p the padding, d the dilation, and s the stride. The exact same formula holds for nn.MaxPool2d. For reference, see the nn.Conv2d page in the PyTorch documentation.
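
For example, plugging conv1 (k=3, p=1, s=1, d=1) and a 32-pixel input into the formula gives int((32 + 2*1 - 1*(3 - 1) - 1)/1 + 1) = 32, so the padded 3×3 convolution preserves the spatial size, while the following MaxPool2d(2, 2) (k=2, s=2) gives int((32 - 1 - 1)/2 + 1) = 16, halving it.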

The convolution part of your model is made up of three (Conv2d + MaxPool2d) blocks. You can easily infer the spatial dimension size of the output with this helper function:

def conv_shape(x, k=1, p=0, s=1, d=1):
    return int((x + 2*p - d*(k - 1) - 1)/s + 1)

Chaining the calls, one per layer, you get the resulting spatial dimension:

>>> w = conv_shape(conv_shape(32, k=3, p=1), k=2, s=2)
>>> w = conv_shape(conv_shape(w, k=3), k=2, s=2)
>>> w = conv_shape(conv_shape(w, k=3), k=2, s=2)

>>> w
2

Since your convolutions have square kernels and identical strides and paddings (horizontal equals vertical), the above calculation holds for both the width and the height of the tensor. Lastly, since the last convolution layer conv3 has 64 filters, the number of elements per batch element entering the fully connected layer is w*w*64 = 2*2*64, i.e. 256.
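
If it helps to see the intermediate values, here is the same computation unrolled layer by layer, with the resulting size of each step in the comments:

w = conv_shape(32, k=3, p=1)  # conv1: 32 -> 32
w = conv_shape(w, k=2, s=2)   # pool:  32 -> 16
w = conv_shape(w, k=3)        # conv2: 16 -> 14
w = conv_shape(w, k=2, s=2)   # pool2: 14 -> 7
w = conv_shape(w, k=3)        # conv3: 7 -> 5
w = conv_shape(w, k=2, s=2)   # pool3: 5 -> 2
print(w * w * 64)             # 256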


However, nothing stops you from calling your layers to find out the output shape!

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Flatten())

        # dummy forward pass to infer the flattened feature size
        n_channels = self.feature_extractor(torch.empty(1, 3, 32, 32)).size(-1)

        self.classifier = nn.Sequential(
            nn.Linear(n_channels, 200),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(100, 10))

    def forward(self, x):
        features = self.feature_extractor(x)
        out = self.classifier(features)
        return out

model = Net()
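
As a quick sanity check, a forward pass with a dummy batch (the batch size of 8 is arbitrary) produces the expected output shape:

out = model(torch.randn(8, 3, 32, 32))
print(out.shape)  # torch.Size([8, 10])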
Answered By: Ivan

Quite late to reply to this, but just for future reference in case someone lands here while searching online; this applies specifically to PyTorch:

The answer already given is more than sufficient and leads to a good understanding of how convolution and pooling layers work, which is for the best in the long run.

However, specifically for PyTorch, there are quicker ways to solve this problem.
One can use the summary method from the torchinfo library and pass it the model with a dummy input size. This will print, along with other info, a summary of the image dimensions through all of the model's layers.
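
For instance, a minimal sketch (assuming torchinfo is installed, e.g. via pip install torchinfo, and reusing the Net model from above):

from torchinfo import summary

model = Net()
summary(model, input_size=(1, 3, 32, 32))

The printed table lists the output shape of every layer, so the input size of the first linear layer can be read straight off the row of the last pooling (or flatten) layer.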

Another quick way, if one does not want to bother with the model's architecture before the fully connected layers, is to use PyTorch's Lazy modules. These are versions (marked as experimental) of standard modules, such as Conv2d, that automatically infer the number of input features, so the input shape does not need to be passed as an argument when setting up the architecture.
For this case, there is a module named nn.LazyLinear, an nn.Linear module that only needs the desired number of output features as an argument.
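
For example, the first fully connected layer of the original model could be declared like this (a minimal sketch; everything else stays unchanged):

self.fc1 = nn.LazyLinear(200)  # in_features (256 here) is inferred on the first forward pass

Note that a lazy module's parameters are uninitialized until data has passed through it once, so run a dummy forward pass before inspecting the weights or creating an optimizer.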

See https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html#torch.nn.modules.lazy.LazyModuleMixin for the whole family of modules and their limitations.

Answered By: Argotera