Using Flatten in a PyTorch v1.0 Sequential module

Question:

Because my CUDA version is 8, I am using torch 1.0.0.

I need a Flatten layer for my Sequential model. Here’s my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
print(torch.__version__)
# 1.0.0
from collections import OrderedDict

layers = OrderedDict()
layers['conv1'] = nn.Conv2d(1, 5, 3)
layers['relu1'] = nn.ReLU()
layers['conv2'] = nn.Conv2d(5, 1, 3)
layers['relu2'] = nn.ReLU()
layers['flatten'] = nn.Flatten()
layers['linear1'] = nn.Linear(3600, 1)
model = nn.Sequential(layers).cuda()

It gives me the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-38-080f7c5f5037> in <module>
      6 layers['conv2'] = nn.Conv2d(5, 1, 3)
      7 layers['relu2'] = nn.ReLU()
----> 8 layers['flatten'] = nn.Flatten()
      9 layers['linear1'] = nn.Linear(3600, 1)
     10 model = nn.Sequential(

AttributeError: module 'torch.nn' has no attribute 'Flatten'

How can I flatten my conv layer output in PyTorch 1.0.0?

Asked By: user13226710


Answers:

Just define your own Flatten module:

import torch.nn as nn
from collections import OrderedDict

class Flatten(nn.Module):
    def forward(self, input):
        # Collapse every dimension except the batch dimension.
        return input.view(input.size(0), -1)

layers = OrderedDict()
layers['conv1'] = nn.Conv2d(1, 5, 3)
layers['relu1'] = nn.ReLU()
layers['conv2'] = nn.Conv2d(5, 1, 3)
layers['relu2'] = nn.ReLU()
layers['flatten'] = Flatten()
layers['linear1'] = nn.Linear(3600, 1)
model = nn.Sequential(layers).cuda()
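
To sanity-check the flattening end to end, you can run a dummy batch through the model (a minimal sketch; the 64x64 input size is an assumption inferred from nn.Linear(3600, 1), since two 3x3 convolutions shrink 64 to 60 and 1 * 60 * 60 = 3600):

import torch

# Dummy batch of 8 single-channel 64x64 images (assumed input size, see above).
x = torch.randn(8, 1, 64, 64).cuda()
out = model(x)
print(out.shape)  # torch.Size([8, 1])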
Answered By: Zabir Al Nazi

Per the PyTorch source, the flatten method is available on torch.Tensor in version 1.0.0.

You tried to access Flatten through the torch.nn package, which is why you got an AttributeError; the nn.Flatten module was only added in a later PyTorch release.

For example:

from torch import Tensor
from torch.nn import Module

class Net(Module):
    def __init__(self):
        super(Net, self).__init__()
        ...

    def forward(self, x):
        ...
        x = Tensor.flatten(x, 1)  # flatten everything except the batch dimension
        ...
        return x
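
For completeness, this is how that pattern could look with the layers from the question (a minimal sketch under that assumption; Net and its attribute names are illustrative, not from the original answer):

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, 3)
        self.conv2 = nn.Conv2d(5, 1, 3)
        self.linear1 = nn.Linear(3600, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.flatten(1)  # Tensor.flatten, keeping the batch dimension
        return self.linear1(x)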
Answered By: Ahx