How to improve rotation in a spatial transformer network

Question:

I am applying a spatial transformer network (STN) to a dataset I created.
The dataset consists of boots and shoes that are slightly rotated (a random rotation between 10° and 30°), as shown in the dataset images figure. I trained and tested my model on the FashionMNIST dataset. I expected to get aligned images, but instead I got something like this:

transformed images

This is how my CNN and STN look:

import torch
import torch.nn as nn
import torch.nn.functional as F

class STN_CNN(nn.Module):
    def __init__(self):
        super(STN_CNN, self).__init__()
        # Classification CNN: operates on the 14x14 output of the STN branch
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=3, stride=1, padding=0),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(),
            nn.Conv2d(10, 16, kernel_size=3, stride=1, padding=0),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 2 * 2, 32),
            nn.ReLU(),
            nn.Linear(32, 10)
        )
        # Localization network: predicts the 2x3 affine parameters
        self.localization = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5, stride=1, padding=0),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(),
            nn.Conv2d(20, 20, kernel_size=5, stride=1, padding=0),
            nn.ReLU()
        )
        self.fc_loc = nn.Sequential(
            nn.Linear(20 * 8 * 8, 20),
            nn.ReLU(),
            nn.Linear(20, 6)
        )
        self.AvgPool = nn.AvgPool2d(2, stride=2)

        # Initialize the final localization layer to output the identity transform
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        # Predict an affine transform from the input, then resample the input with it
        x_loc = self.localization(x)
        x_loc = x_loc.view(-1, 20 * 8 * 8)
        theta = self.fc_loc(x_loc)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        x = self.AvgPool(x)

        return x

    def forward(self, x):
        x = self.stn(x)
        x = self.cnn(x)
        x = x.view(-1, 16 * 2 * 2)
        x = self.classifier(x)

        return x
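One way to check that the STN branch starts out as a no-op is to verify the identity initialization directly. This is a minimal sketch, independent of the model above, assuming 28×28 single-channel inputs like FashionMNIST; with the identity affine parameters, `affine_grid` followed by `grid_sample` should return (approximately) the unchanged input:

```python
import torch
import torch.nn.functional as F

# Identity affine parameters, matching the fc_loc bias initialization
theta = torch.tensor([[1., 0., 0., 0., 1., 0.]]).view(-1, 2, 3)

x = torch.rand(1, 1, 28, 28)
grid = F.affine_grid(theta, x.size(), align_corners=False)
out = F.grid_sample(x, grid, align_corners=False)

# With an identity theta, the resampled image should match the input
identity_ok = torch.allclose(out, x, atol=1e-4)
```

If this check fails, the problem is in the sampling step; if it passes, the issue lies elsewhere (e.g. training data or the localization network).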

loss and accuracy

I even trained on my dataset for 100 epochs, but I'm not getting any improvement.
Could someone tell me how to improve the rotation part of my STN? Or, if I'm doing something wrong, please let me know. I'd be really happy if someone could help.

Thank you in advance.

Asked By: zakaria14


Answers:

The problem was that I trained and tested the model with different pixel ranges. For example, the pixels in the training images were normalized to [0, 1], while the test images were left unnormalized in [0, 255]. Matplotlib happily plots both formats, so the mismatch wasn't visible in the images, but it badly affected the model's behavior.
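The fix is to make sure training and test images share the same pixel range before they reach the model. A minimal sketch of the mismatch and its correction (shapes and ranges here are illustrative, not taken from the original dataset):

```python
import torch

# Training images normalized to [0, 1] (e.g. via torchvision's ToTensor())
train_batch = torch.rand(8, 1, 28, 28)

# Test images mistakenly left as raw pixel values in [0, 255]
test_raw = torch.randint(0, 256, (8, 1, 28, 28), dtype=torch.uint8)

# The fix: scale the test images into the same [0, 1] range before inference
test_batch = test_raw.float() / 255.0
```

Applying the same preprocessing pipeline (the same transforms) to both the training and test loaders avoids this class of bug entirely.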

Answered By: zakaria14