PyTorch nn.Module won't unbatch operations

Question:

I have an nn.Module whose forward function takes two inputs. Inside the function, I multiply one of the inputs, x2, by a set of trainable parameters and then concatenate the result with the other input, x1.

import torch
import torch.nn as nn
from torch import Tensor

class ConcatMe(nn.Module):
    def __init__(self, pad_len, emb_size):
        super(ConcatMe, self).__init__()
        # DEVICE is defined elsewhere in my script
        self.W = nn.Parameter(torch.randn(pad_len, emb_size).to(DEVICE), requires_grad=True)
        self.emb_size = emb_size

    def forward(self, x1: Tensor, x2: Tensor):
        # scale x2 by the learned weights, then append to x1 along the feature dim
        cat = self.W * torch.reshape(x2, (1, -1, 1))
        return torch.cat((x1, cat), dim=-1)

From my understanding, one is supposed to be able to write operations in PyTorch’s nn.Modules as if the inputs had a batch size of 1. For some reason that isn’t the case here: I’m getting an error indicating that PyTorch is still accounting for the batch size.

x1 = torch.randn(100, 2, 512)   # (pad_len, batch, emb_size)
x2 = torch.randint(10, (2, 1))  # (batch, 1)
concat = ConcatMe(100, 512)
concat(x1, x2)

-----------------------------------------------------------------------------------
File "/home/my/file/path.py, line 0, in forward
    cat = self.W * torch.reshape(x2, (1, -1, 1))
RuntimeError: The size of tensor a (100) must match the size of tensor b (2) at non-singleton dimension 1

I made a for loop to patch the issue as shown below:

class ConcatMe(nn.Module):
    def __init__(self, pad_len, emb_size):
        super(ConcatMe, self).__init__()
        self.W = nn.Parameter(torch.randn(pad_len, emb_size).to(DEVICE), requires_grad=True)
        self.emb_size = emb_size

    def forward(self, x1: Tensor, x2: Tensor):
        batch_size = x2.shape[0]
        # buffer shaped like x1; every slice is overwritten below
        cat = torch.ones(x1.shape).to(DEVICE)

        # multiply W with each batch element one at a time
        for i in range(batch_size):
            cat[:, i, :] = self.W * x2[i]

        return torch.cat((x1, cat), dim=-1)

but I feel like there’s a more elegant solution. Does it have something to do with the fact that I’m creating parameters inside nn.Module? If so, what solution can I implement that doesn’t require a for loop?

Asked By: mehsheenman


Answers:

"From my understanding, one is supposed to be able to write operations in PyTorch’s nn.Modules as if the inputs had a batch size of 1."

I’m not sure where you got this assumption; it is definitely not true. On the contrary, you always need to write your operations so that they can handle the general case of an arbitrary batch dimension.
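
The built-in layers are written exactly this way: nn.Linear, for example, operates on the last dimension and broadcasts over any number of leading batch dimensions. A quick illustration (a minimal standalone sketch, not part of your code):

import torch
import torch.nn as nn

lin = nn.Linear(512, 64)
print(lin(torch.randn(2, 512)).shape)       # torch.Size([2, 64])
print(lin(torch.randn(100, 2, 512)).shape)  # torch.Size([100, 2, 64])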

Judging from your second implementation, it seems like you’re trying to multiply two tensors with incompatible shapes. So in order to fix that, you’d have to give the parameter an explicit singleton dimension for the batch:

        self.W = torch.nn.Parameter(torch.randn(pad_len, 1, emb_size), requires_grad=True)
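
For completeness, plugging that into the module gives something like the following (a minimal sketch; I’ve dropped the .to(DEVICE) call, since the usual pattern is to move the whole module with model.to(DEVICE) after construction):

import torch
import torch.nn as nn
from torch import Tensor

class ConcatMe(nn.Module):
    def __init__(self, pad_len, emb_size):
        super().__init__()
        # the singleton middle dim lets W broadcast across the batch
        self.W = nn.Parameter(torch.randn(pad_len, 1, emb_size), requires_grad=True)
        self.emb_size = emb_size

    def forward(self, x1: Tensor, x2: Tensor):
        # (pad_len, 1, emb_size) * (1, batch, 1) -> (pad_len, batch, emb_size)
        cat = self.W * torch.reshape(x2, (1, -1, 1))
        return torch.cat((x1, cat), dim=-1)

x1 = torch.randn(100, 2, 512)
x2 = torch.randint(10, (2, 1))
concat = ConcatMe(100, 512)
print(concat(x1, x2).shape)  # torch.Size([100, 2, 1024])

Now the multiplication broadcasts over the batch dimension, so the loop is unnecessary.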

To understand things like this better, it would help to read up on broadcasting in PyTorch.
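
Concretely, broadcasting aligns shapes from the right and stretches size-1 dimensions to match, which is exactly what makes the fix above work:

import torch

W = torch.randn(100, 1, 512)                        # (pad_len, 1, emb_size)
x2r = torch.randint(10, (2, 1)).reshape(1, -1, 1)   # (1, batch, 1)

# (100, 1, 512) * (1, 2, 1) broadcasts to (100, 2, 512)
print((W * x2r).shape)  # torch.Size([100, 2, 512])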

Answered By: flawr