RuntimeError due to inplace operation in GAN generator architecture with skip connections

Question:

I get the following error for a GAN model I am using to perform image colorization. It uses the LAB color space as is common in image colorization. The generator generates the a ad b channels for a given L channel. The discriminator is fed all three channels after concatenation.

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 64, 128, 128]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I believe the error is due to the skip connections but I cannot quite put my finger on it. Any help would be appreciated!

Here is the model:

class NetGen(nn.Module):
    '''Generator'''
    def __init__(self):
        super(NetGen, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False)
        self.bnorm1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bnorm2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)

        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
        self.bnorm3 = nn.BatchNorm2d(256)
        self.relu3 = nn.LeakyReLU(0.1)

        self.conv4 = nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False)
        self.bnorm4 = nn.BatchNorm2d(512)
        self.relu4 = nn.LeakyReLU(0.1)

        self.conv5 = nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False)
        self.bnorm5 = nn.BatchNorm2d(512)
        self.relu5 = nn.LeakyReLU(0.1)

        self.deconv6 = nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm6 = nn.BatchNorm2d(512)
        self.relu6 = nn.ReLU()

        self.deconv7 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm7 = nn.BatchNorm2d(256)
        self.relu7 = nn.ReLU()

        self.deconv8 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm8 = nn.BatchNorm2d(128)
        self.relu8 = nn.ReLU()

        self.deconv9 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm9 = nn.BatchNorm2d(64)
        self.relu9 = nn.ReLU()

        self.deconv10 = nn.ConvTranspose2d(64, 2, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.tanh = nn.Tanh()
        

    def forward(self, x):
        h = x
        h = self.conv1(h)
        h = self.bnorm1(h)
        h = self.relu1(h) 
        pool1 = h

        h = self.conv2(h)
        h = self.bnorm2(h)
        h = self.relu2(h) 
        pool2 = h

        h = self.conv3(h) 
        h = self.bnorm3(h)
        h = self.relu3(h)
        pool3 = h

        h = self.conv4(h) 
        h = self.bnorm4(h)
        h = self.relu4(h)
        pool4 = h

        h = self.conv5(h) 
        h = self.bnorm5(h)
        h = self.relu5(h)

        h = self.deconv6(h)
        h = self.bnorm6(h)
        h = self.relu6(h) 
        h += pool4

        h = self.deconv7(h)
        h = self.bnorm7(h)
        h = self.relu7(h) 
        h += pool3

        h = self.deconv8(h)
        h = self.bnorm8(h)
        h = self.relu8(h)
        h += pool2

        h = self.deconv9(h)
        h = self.bnorm9(h)
        h = self.relu9(h)
        h += pool1

        h = self.deconv10(h)
        h = self.tanh(h) 
        return h

class NetDis(nn.Module):
    '''Discriminator'''
    def __init__(self):
        super(NetDis, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),

            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),

            nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),

            nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),


            nn.Conv2d(512, 512, 8, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 1, 1, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

Here is the weight init function:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

Here is the training and validation code:

class Trainer:
    def __init__(self, epochs, batch_size, learning_rate, num_workers):
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_workers = num_workers
        self.train_paths = train_paths
        self.val_paths = val_paths        
        self.real_label = 1
        self.fake_label = 0

    def train(self):             
        train_dataset = ColorizeData(paths=self.train_paths)
        train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.num_workers,pin_memory=True, drop_last = True)
        # Model
        model_G = NetGen().to(device)
        model_D = NetDis().to(device)

        model_G.apply(weights_init)
        model_D.apply(weights_init)

        optimizer_G = torch.optim.Adam(model_G.parameters(),
                             lr=self.learning_rate, betas=(0.5, 0.999),
                             eps=1e-8, weight_decay=0)
        optimizer_D = torch.optim.Adam(model_D.parameters(),
                             lr=self.learning_rate, betas=(0.5, 0.999),
                             eps=1e-8, weight_decay=0)
        
        criterion = nn.BCELoss()
        L1 = nn.L1Loss()

        model_G.train()
        model_D.train()


        # train loop
        for epoch in range(self.epochs):
            print("Starting Training Epoch " + str(epoch + 1))
            for i, data in enumerate(tqdm(train_dataloader)):                                                    
                inputs, input_ab, input_l = data
                inputs = inputs.to(device)
                input_ab = input_ab.to(device)
                input_l = input_l.to(device)


                model_D.zero_grad()
                label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
                output = model_D(torch.cat([input_l, input_ab], dim=1))
                errD_real = criterion(torch.squeeze(output), label)
                errD_real.backward()

                fake = model_G(input_l)
                label.fill_(self.fake_label)

                output = model_D(torch.cat([input_l, fake.detach()], dim=1))
                errD_fake = criterion(torch.squeeze(output), label)
                errD_fake.backward()
                errD = errD_real + errD_fake
                optimizer_D.step()

                model_G.zero_grad()
                label.fill_(self.real_label)  
                output = model_D(torch.cat([input_l, fake], dim=1))
                errG = criterion(torch.squeeze(output), label)
                errG_L1 = L1(fake.view(fake.size(0),-1), input_ab.view(input_ab.size(0),-1))
                errG = errG + 100 * errG_L1
                errG.backward()
                optimizer_G.step()  


            print(f'Training: Epoch {epoch + 1} tt Discriminator Loss: {
                errD / len(train_dataloader)}  tt Generator Loss: {
                errG / len(train_dataloader)}')
            
            if (epoch + 1) % 1 == 0:
                errD_val, errG_val, val_len = self.validate(model_D, model_G, criterion, L1)
                print(f'Validation: Epoch {epoch + 1} tt Discriminator Loss: {
                        errD_val / val_len}  tt Generator Loss: {
                        errG_val / val_len}')
                
            torch.save(model_G.state_dict(), '../Results/Model_GAN/Generator/saved_model_' + str(epoch + 1) + '.pth')
            torch.save(model_D.state_dict(), '../Results/Model_GAN/Discriminator/saved_model_' + str(epoch + 1) + '.pth')


    def validate(self, model_D, model_G, criterion, L1):

        model_G.eval()
        model_D.eval()
        with torch.no_grad():
            valid_loss = 0.0
            val_dataset = ColorizeData(paths=self.val_paths)
            val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, drop_last = True)
            for i, data in enumerate(val_dataloader):
                inputs, input_ab, input_l = data
                inputs = inputs.to(device)
                input_ab = input_ab.to(device)
                input_l = input_l.to(device)

                label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
                output = model_D(torch.cat([input_l, input_ab], dim=1))
                errD_real = criterion(torch.squeeze(output), label)

                fake = model_G(input_l)
                label.fill_(self.fake_label)
                output = model_D(torch.cat([input_l, fake.detach()], dim=1))
                errD_fake = criterion(torch.squeeze(output), label)
                
                errD = errD_real + errD_fake

                label.fill_(self.real_label)  
                output = model_D(torch.cat([input_l, fake], dim=1))
                errG = criterion(torch.squeeze(output), label)
                errG_L1 = L1(fake.view(fake.size(0),-1), input_ab.view(input_ab.size(0),-1))
                errG = errG + 100 * errG_L1

        return errD, errG, len(val_dataloader)

EDIT
As suggested by @manaclan here is the code I use to run the pipeline:

trainer = Trainer(epochs = 100, batch_size = 64, learning_rate = 0.0002, num_workers = 2)
trainer.train()

Here is the data loader:

class ColorizeData(Dataset):
    def __init__(self, paths):
        self.input_transform = T.Compose([T.ToTensor(),
                                          T.Resize(size=(256,256)),
                                          T.Grayscale(),
                                          T.Normalize((0.5), (0.5))
                                          ])
        self.lab_transform = T.Compose([T.ToTensor(),
                                          T.Resize(size=(256,256)),
                                          T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                          ])
        self.paths = paths

    def __len__(self) -> int:
        return len(self.paths)
    
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image = Image.open(self.paths[index]).convert("RGB")
        input_image = self.input_transform(image)
        image_lab = rgb2lab(image)
        image_lab = self.lab_transform(image_lab)
        image_l = image_lab[0, :, :]
        image_ab = image_lab[1:3, :, :]
        return (input_image.float(), image_ab.float(), image_l.float().reshape(1, 256, 256))

Here are the imports:

from typing import Tuple
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch
import numpy as np
import os
import torch.nn as nn
import torchvision.models as models
import torchvision
import torch.nn.functional as functional
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from PIL import Image
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io

from torchvision.transforms.functional import resize

To reproduce the error, just use any dataset of color images.
I have the following code to get my train, test, and validation images from the folder "Dataset":

path = "../Dataset/"
paths = np.array(glob.glob(path + "/*.jpg"))
rand_indices = np.random.permutation(len(paths))          # Number of images in dataset
train_indices, val_indices, test_indices = rand_indices[:3600], rand_indices[3600:4000], rand_indices[4000:]
train_paths = paths[train_indices]
val_paths = paths[val_indices]
test_paths = paths[test_indices]

NOTE: I am using Google Colab, maybe this might be a potential problem? Also, I am using torch version 1.10.0+cu111.
I did use a sequential model without skip connections for the generator before this, and I did not have this error then.

Asked By: JayShreekumar

||

Answers:

Maybe try to use the output of the layers directly for the skip connections, like this:

def forward(self, x):
    h = x
    h = self.conv1(h)
    h = self.bnorm1(h)
    h1 = self.relu1(h) 

    h = self.conv2(h1)
    h = self.bnorm2(h)
    h2 = self.relu2(h) 

    h = self.conv3(h2) 
    h = self.bnorm3(h)
    h3 = self.relu3(h)

    h = self.conv4(h3) 
    h = self.bnorm4(h)
    h4 = self.relu4(h)

    h = self.conv5(h4) 
    h = self.bnorm5(h)
    h = self.relu5(h)

    h = self.deconv6(h5)
    h = self.bnorm6(h)
    h = self.relu6(h) 
    h += h4

    h = self.deconv7(h)
    h = self.bnorm7(h)
    h = self.relu7(h) 
    h += h3

    h = self.deconv8(h)
    h = self.bnorm8(h)
    h = self.relu8(h)
    h += h2

    h = self.deconv9(h)
    h = self.bnorm9(h)
    h = self.relu9(h)
    h += h1

    h = self.deconv10(h)
    h = self.tanh(h) 
    return h
Answered By: Theodor Peifer

So apparently, the problem is the inplace skip connection written as h += poolX. Writing this update out of place as h = h + poolX fixed it. h is needed for gradient calculation in some layers, so inplace modification will mess it up.

Answered By: JayShreekumar