In PyTorch, how do you multiply a tensor of size (b, c, h, w) by a tensor of size (c)?

Question:

I have to normalize a tensor of size (b, c, h, w) with two tensors of size (c), which hold the per-channel mean and standard deviation.

I cannot figure out how to multiply a tensor of shape, let's say, torch.Size([1, 3, 128, 128]) with a tensor of shape torch.Size([3]).

What I want to accomplish is: take the first element of the smaller tensor and multiply the first 128×128 slice of the larger tensor by it, then do the same with the second element and the second 128×128 slice, and so on.
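For reference, the operation described above can be written as an explicit loop over the channel dimension. This is a minimal sketch, not the asker's code: the function name scale_channels_loop and the example values in scale are hypothetical. It is slow for large inputs but makes the intended semantics unambiguous.

```python
import torch

def scale_channels_loop(img: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Multiply channel k of every image in the batch by scale[k] (explicit loop)."""
    out = img.clone()
    for k in range(img.shape[1]):
        # img[:, k] is the (b, h, w) block belonging to channel k
        out[:, k] = img[:, k] * scale[k]
    return out

img = torch.rand(1, 3, 128, 128)
scale = torch.tensor([1.0, 2.0, 3.0])  # hypothetical per-channel factors
out = scale_channels_loop(img, scale)
print(out.shape)  # torch.Size([1, 3, 128, 128])
```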

def normalize(img, mean, std):
    """ Normalizes an image tensor.

    # Parameters:
        @img, torch.tensor of size (b, c, h, w)
        @mean, torch.tensor of size (c)
        @std, torch.tensor of size (c)

    # Returns the normalized image
    """
    # TODO: 1. Implement normalization doing channel-wise z-score normalization.

    img * mean                              #try1: this doesn't work
    torch.mul(img.view(3,128,128), mean)    #try2: this doesn't work

    return img

Both of my attempts throw the same error: RuntimeError: The size of tensor a (128) must match the size of tensor b (3) at non-singleton dimension 3.

I imagine you could create a tensor of the needed size, fill it with the necessary values and multiply by that, but I would imagine there is a better solution.
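The "full-size tensor" idea does work, and it does not have to copy any data. A minimal sketch, assuming example mean values and a hypothetical variable name mean_full: Tensor.expand_as returns a view whose expanded dimensions have stride 0, so the 128×128 planes are not materialized. Broadcasting (as in the answer below) performs this expansion implicitly, so the explicit step is mainly illustrative.

```python
import torch

img = torch.rand(1, 3, 128, 128)
mean = torch.tensor([0.5, 0.4, 0.3])  # hypothetical per-channel values

# A (1, 3, 128, 128) view holding mean[k] everywhere in channel k.
# expand_as() sets the stride of the expanded dims to 0, so nothing is copied.
mean_full = mean.view(1, -1, 1, 1).expand_as(img)
print(mean_full.shape)     # torch.Size([1, 3, 128, 128])
print(mean_full.stride())  # (3, 1, 0, 0)

out = img * mean_full
```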

Asked By: user14604864


Answers:

img * mean.reshape(1,3,1,1)

This reshapes the mean tensor to (1, 3, 1, 1) so that PyTorch can broadcast it against img: each of the 3 values is applied to the corresponding 128×128 channel slice.
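Applied to the TODO in the question, a minimal sketch of channel-wise z-score normalization might look like the following. Note that z-scoring subtracts the mean and divides by the standard deviation rather than multiplying; reshape(1, -1, 1, 1) lets PyTorch infer c, so the function is not tied to 3 channels. The statistics here are computed from the random batch purely so the result can be checked; in practice mean and std would come from your dataset.

```python
import torch

def normalize(img: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Channel-wise z-score normalization of a (b, c, h, w) tensor."""
    # (c,) -> (1, c, 1, 1) so the statistics broadcast over b, h and w
    mean = mean.reshape(1, -1, 1, 1)
    std = std.reshape(1, -1, 1, 1)
    return (img - mean) / std

img = torch.rand(2, 3, 128, 128)
mean = img.mean(dim=(0, 2, 3))  # per-channel mean, shape (3,)
std = img.std(dim=(0, 2, 3))    # per-channel std, shape (3,)
out = normalize(img, mean, std)
print(out.mean(dim=(0, 2, 3)))  # close to 0 for each channel
print(out.std(dim=(0, 2, 3)))   # close to 1 for each channel
```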

Edit for details:
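PyTorch broadcasting compares shapes starting from the last (trailing) dimension and works backwards, and any missing leading dimensions are treated as size 1. That is why it can infer the leading dimensions (e.g. img * mean.reshape(3,1,1) also works in your case). The trailing dimensions, however, must be spelled out: each one must either be 1 or match the corresponding size of the tensor you are multiplying with. A plain mean of shape (3) is aligned with the last dimension of img (size 128), which is exactly the mismatch reported in the error.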
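The alignment rule can be checked without allocating any data using torch.broadcast_shapes (available since PyTorch 1.8). This sketch compares the image shape from the question with a few candidate shapes for mean; indexing with None (mean[:, None, None]) is shown as an equivalent way to get the (3, 1, 1) shape.

```python
import torch

img_shape = (1, 3, 128, 128)

# Shapes are aligned from the right; missing leading dims count as size 1.
print(torch.broadcast_shapes(img_shape, (1, 3, 1, 1)))  # torch.Size([1, 3, 128, 128])
print(torch.broadcast_shapes(img_shape, (3, 1, 1)))     # torch.Size([1, 3, 128, 128])

# Indexing with None appends the two trailing singleton dims.
mean = torch.tensor([0.5, 0.4, 0.3])
print(mean[:, None, None].shape)  # torch.Size([3, 1, 1])

# A bare (3,) lines up with the last dim (w = 128), so broadcasting fails.
try:
    torch.broadcast_shapes(img_shape, (3,))
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```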

Answered By: Allie