Why does torch.nn.Conv2d() divide the image into 9 parts?

Question:

Sorry for the stupid question, but why does torch.nn.Conv2d() divide the image into 9 parts?

import torch
from torch import nn
import cv2

img = cv2.imread("image_game/eldenring 2022-12-14 19-29-50.png")
cv2.imshow('input', img)
size = img.shape #  (720, 1280, 3)
img = img.reshape((1, img.shape[2], size[0], size[1]))
img = torch.tensor(img, dtype=torch.float32)  #  torch.Size([1, 3, 720, 1280])

c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=2, stride=1)
img = c1(img)

size = img.shape
img = img.reshape((size[2], size[3], size[1])).detach().numpy()
cv2.imshow('output', img)
cv2.waitKey(0)

This returns:

input image:
input
output image:
output

I want this:

gif

edit:

When I use

c1 = nn.Conv2d(1, 1, kernel_size=(3, 3), padding=2, stride=1)

instead of

c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=2, stride=1)

I get what I want, but how do I do it when there are more channels?

Asked By: Karol Szymczak

Answers:

I’m sorry that the description of the question was unclear.
Javier TG solved my problem in a comment:

The issue is with using reshape to permute the axes -> opencv’s imread
gives an array of size (H, W, 3), so to get the pytorch’s (1, 3, H, W)
representation, transpose (in numpy) and permute (in pytorch) should
be used instead. Try substituting the first reshape with img =
img[None].transpose(0, 3, 1, 2), and the last reshaping with img =
img[0].permute(1, 2, 0).detach().numpy() – Javier TG
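
To illustrate the point of the comment, here is a minimal sketch on a tiny array. The 2x2 "image" and the names `wrong` and `right` are made up for illustration; the idea is only that reshape keeps the memory order while transpose actually moves the axes.

```python
import numpy as np

# Tiny stand-in for an (H, W, 3) image with H = W = 2.
# The value at img[i, j, c] is 6*i + 3*j + c, so the layout is easy to follow.
img = np.arange(2 * 2 * 3).reshape(2, 2, 3)

wrong = img.reshape(3, 2, 2)    # same memory order, dimensions only relabelled
right = img.transpose(2, 0, 1)  # axes actually moved: channel-first (3, H, W)

print(right[0])  # channel 0 of every pixel: [[0 3] [6 9]]
print(wrong[0])  # first four values in memory, channels mixed: [[0 1] [2 3]]
```

Both results have shape (3, 2, 2), which is why the mistake is easy to miss: only the transposed version holds one colour channel per plane.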

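I thought the problem was in nn.Conv2d(), but I had simply rearranged the axes incorrectly: reshape keeps the memory order and only relabels the dimensions, so it does not turn (H, W, 3) into a true channel-first (3, H, W) layout.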

Corrected code (padding=1 with a 3x3 kernel and stride 1 also keeps the output the same size as the input):

import torch
from torch import nn
import cv2

img = cv2.imread("image_game/eldenring 2022-12-14 19-29-50.png")
cv2.imshow('input', img)  # (720, 1280, 3)
img = img[None].transpose(0, 3, 1, 2)
img = torch.as_tensor(img).float()  # torch.Size([1, 3, 720, 1280])

c1 = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=1, stride=1)
img = c1(img)

img = img[0].permute(1, 2, 0).detach().numpy()  # (720, 1280, 3)
cv2.imshow('output', img)
cv2.waitKey(0)
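
The same pipeline can be checked without an image file or a display window. The sketch below is only an assumption-laden sanity check: the random 8x12 array stands in for the cv2.imread result, and the variable names are hypothetical.

```python
import numpy as np
import torch
from torch import nn

# Synthetic stand-in for the cv2.imread output: (H, W, 3), uint8.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(8, 12, 3), dtype=np.uint8)

# (H, W, 3) -> (1, 3, H, W) by moving the axes, not by reshaping.
x = torch.as_tensor(img[None].transpose(0, 3, 1, 2)).float()

# Same layer as in the corrected code; weights are randomly initialised.
conv = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=1, stride=1)
with torch.no_grad():
    y = conv(x)

# (1, 3, H, W) -> (H, W, 3) for display or further NumPy processing.
out = y[0].permute(1, 2, 0).numpy()

print(x.shape)    # torch.Size([1, 3, 8, 12])
print(out.shape)  # (8, 12, 3)
```

Each plane x[0, c] equals img[:, :, c], so the convolution sees real colour channels rather than slices of interleaved memory.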
Answered By: Karol Szymczak