Moving member tensors with module.to() in PyTorch

Question:

I am building a Variational Autoencoder (VAE) in PyTorch and have a problem writing device-agnostic code. The autoencoder is a child of nn.Module, with an encoder and a decoder network that are nn.Module children as well. All weights of the network can be moved from one device to another by calling net.to(device).

The problem I have is with the reparametrization trick:

encoding = mu + noise * sigma

The noise is a tensor of the same size as mu and sigma, stored as a member variable of the autoencoder module. It is initialized in the constructor and resampled in-place at each training step. I do it that way to avoid constructing a new noise tensor at every step and pushing it to the desired device. Additionally, I want the noise to stay fixed during evaluation. Here is the code:

class VariationalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128

        self._train_noise = torch.randn(batch_size, embedding_size)
        self._eval_noise = torch.randn(1, embedding_size)
        self.noise = self._train_noise

        # Create encoder
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)

    def train(self, mode=True):
        super(VariationalGenerator, self).train(mode)
        self.noise = self._train_noise

    def eval(self):
        super(VariationalGenerator, self).eval()
        self.noise = self._eval_noise

    def forward(self, inputs):
        # Calculate parameters of embedding space
        mu, log_sigma = self.encoder.forward(inputs)
        # Resample noise if training
        if self.training:
            self.noise.normal_()
        # Reparametrize noise to embedding space
        inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
        # Decode to image
        inputs = self.decoder(inputs)

        return inputs, mu, log_sigma

When I now move the autoencoder to the GPU with net.to('cuda:0'), I get an error in the forward pass because the noise tensor is not moved.

I don't want to add a device parameter to the constructor, because then it is still not possible to move the module to another device later. I also tried wrapping the noise in nn.Parameter so that it is affected by net.to(), but that gives an error from the optimizer, since the noise is flagged with requires_grad=False.

Does anyone have a solution to move all of the module's tensors with net.to()?

Asked By: tilman151


Answers:

Use this:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Now move both the model and every tensor you use to that device:

net.to(device)
input = input.to(device)
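
For context, here is a minimal sketch of that pattern with an ordinary module (the model, learning rate, and batch shapes below are placeholders, not from the original answer); note that it does not by itself move member tensors like the noise in the question:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(128, 128).to(device)        # any nn.Module; its weights move with .to()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for _ in range(10):
    batch = torch.randn(32, 128).to(device)   # move each input tensor to the same device
    loss = model(batch).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
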
Answered By: Ran Elgiser

After some more trial and error I found two methods:

  1. Use buffers: By replacing self._train_noise = torch.randn(batch_size, embedding_size) with self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size)), the noise tensor is added to the module as a buffer. This lets net.to(device) affect it, too. Additionally, the tensor is now part of the state_dict.
  2. Override net.to(device): Using this, the noise stays out of the state_dict.

    def to(self, device):
        new_self = super(VariationalGenerator, self).to(device)
        new_self._train_noise = new_self._train_noise.to(device)
        new_self._eval_noise = new_self._eval_noise.to(device)
    
        return new_self
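
A quick sanity check of the second approach (a sketch; it assumes Encoder and Decoder are defined and a CUDA device is available):

    net = VariationalGenerator(3, 3)            # hypothetical channel counts
    net = net.to('cuda:0')                      # the override also moves the noise tensors
    print(net._train_noise.device)              # cuda:0
    print('_train_noise' in net.state_dict())   # False - the noise is not a buffer or parameter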
    
Answered By: tilman151

A better version of tilman151's second approach is probably to override _apply rather than to. That way net.cuda(), net.float(), etc. will all work as well, since they all call _apply rather than to (as can be seen in the source, which is simpler than you might think):

def _apply(self, fn):
    super(VariationalGenerator, self)._apply(fn)
    self._train_noise = fn(self._train_noise)
    self._eval_noise = fn(self._eval_noise)
    return self
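
For illustration, a hedged usage sketch (assuming Encoder and Decoder are defined and CUDA is available):

net = VariationalGenerator(3, 3)   # hypothetical channel counts
net.cuda()                         # routed through _apply, so the noise tensors move too
net.float()                        # dtype conversion is also routed through _apply
print(net._train_noise.device, net._train_noise.dtype)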
Answered By: Danica

With the override below you can pass the same arguments to your tensors and to the module:

def to(self, *args, **kwargs):
    module = super(VariationalGenerator, self).to(*args, **kwargs)
    module._train_noise = self._train_noise.to(*args, **kwargs)
    module._eval_noise = self._eval_noise.to(*args, **kwargs)

    return module
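
With that in place, both positional and keyword call patterns should carry over to the noise tensors as well, for example:

net = VariationalGenerator(3, 3)    # hypothetical channel counts
net = net.to('cuda:0')              # positional device argument
net = net.to(dtype=torch.float64)   # keyword dtype argument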
Answered By: Tobias

You can use nn.Module buffers and parameters. Both are taken into account when calling .to(device) and are moved to the device.
Parameters are updated by the optimizer (so they need requires_grad=True); buffers are not.
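
As a small illustration of that difference (a sketch, not part of the original answer):

    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super(Demo, self).__init__()
            self.weight = nn.Parameter(torch.randn(3))       # trainable: shows up in parameters()
            self.register_buffer('noise', torch.randn(3))    # not trainable: skipped by the optimizer

    demo = Demo()
    print([name for name, _ in demo.named_parameters()])  # ['weight'] - only the parameter
    print(sorted(demo.state_dict().keys()))                # ['noise', 'weight'] - both are saved
    demo.to(torch.device('cpu'))                           # and both are moved by .to()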

So in your case, I'd write the constructor as:

    def __init__(self, input_nc, output_nc):
        super(VariationalGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        embedding_size = 128

        # --- CHANGED LINES ---
        self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size))
        self.register_buffer('_eval_noise', torch.randn(1, embedding_size))
        # --- CHANGED LINES ---

        self.noise = self._train_noise

        # Create encoder
        self.encoder = Encoder(input_nc, embedding_size)
        # Create decoder
        self.decoder = Decoder(output_nc, embedding_size)
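
With the buffers registered this way, a quick check (again a sketch, assuming Encoder and Decoder are defined and CUDA is available):

    net = VariationalGenerator(3, 3)            # hypothetical channel counts
    net = net.to('cuda:0')                      # buffers are moved together with the parameters
    print(net._train_noise.device)              # cuda:0
    print('_train_noise' in net.state_dict())   # True - buffers are part of the state_dict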

Answered By: lopisan