Wild discrepancies between training DeepLabV3 ResNet on Google Colab versus on a local machine

Question:

I am attempting to train DeepLabV3 (ResNet backbone) to perform semantic segmentation on a custom dataset. I had been working on my local machine, but my GPU is just a small Quadro T1000, so I decided to move my training onto Google Colab to take advantage of their GPU instances and get better results.

Whilst I get the speed increase I was hoping for, I am getting wildly different training losses on Colab compared to my local machine. I have copied and pasted the exact same code, so the only difference I can find would be in the dataset. I am using the exact same dataset, except that the one on Colab is a copy of the local dataset stored on Google Drive. I have noticed that Drive orders files differently from Windows, but I can’t see how this would be a problem, since I randomly shuffle the dataset. I understand that this random splitting can cause small differences in the outputs, but a difference of about 10x in the training losses doesn’t make sense.
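For reference, glob() does not guarantee any particular ordering, and since I split the input and target lists separately (with the same seed), the pairing only survives if both lists come back in the same order on both machines. A rough sketch of the kind of check I have in mind; it assumes each mask shares its image’s filename stem, which may not match your naming:

import pathlib

root = pathlib.Path.cwd() / 'train'
inputs = sorted((root / 'input').glob('*'))    # sort so the order is platform-independent
targets = sorted((root / 'target').glob('*'))

assert len(inputs) == len(targets), "input/target counts differ"
for image_file, mask_file in zip(inputs, targets):
    # assumes masks are named after their images; purely illustrative
    assert image_file.stem == mask_file.stem, f"misaligned pair: {image_file.name} / {mask_file.name}"
print(f"{len(inputs)} image/mask pairs are aligned")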

I have also tried running the version on Colab with different random seeds, different batch sizes, different train_test_split parameters, and with the optimizer changed from SGD to Adam; in every case the model still converges very early, at a loss of around 0.5.
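For completeness, seeding only train_test_split leaves the other sources of randomness (weight initialisation, the augmentations, DataLoader shuffling, cuDNN kernels) unpinned, so the seed I pass there doesn’t make the two runs directly comparable. A minimal sketch of the usual knobs; with num_workers > 0 the loader workers would also need a worker_init_fn for full determinism:

import random
import numpy as np
import torch

def seed_everything(seed: int = 142):
    """Pin the common sources of randomness so runs are comparable."""
    random.seed(seed)                            # Python RNG (used by some augmentations)
    np.random.seed(seed)                         # NumPy RNG (sklearn, albumentations)
    torch.manual_seed(seed)                      # PyTorch RNG (weight init, shuffling)
    torch.cuda.manual_seed_all(seed)             # every visible GPU
    torch.backends.cudnn.deterministic = True    # deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False       # disable autotuning that breaks determinism

seed_everything(142)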

Here is my code:

import torch
from torch.utils import data
from torchvision import transforms
from customdatasets import SegmentationDataSet
import pathlib
from sklearn.model_selection import train_test_split
from customtransforms import Compose, AlbuSeg2d, DenseTarget
from customtransforms import MoveAxis, Normalize01, Resize
import albumentations
import matplotlib.pyplot as plt
import time
import GPUtil



def get_filenames_of_path(path: pathlib.Path, ext: str = '*'):
    """Returns a list of files in a directory/path. Uses pathlib."""
    filenames = [file for file in path.glob(ext) if file.is_file()]
    return filenames


if __name__ == '__main__':

    root = pathlib.Path.cwd() / 'train'
    inputs = get_filenames_of_path(root / 'input')
    targets = get_filenames_of_path(root / 'target')

    

    # training transformations and augmentations
    transforms_training = Compose([
        Resize(input_size=(128, 128, 3), target_size=(128, 128)),
        AlbuSeg2d(albu=albumentations.HorizontalFlip(p=0.5)),
        MoveAxis(),
        Normalize01()
    ])

    # validation transformations
    transforms_validation = Compose([
        Resize(input_size=(128, 128, 3), target_size=(128, 128)),
        MoveAxis(),
        Normalize01()
    ])

    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    random_seed = 142
    train_size = 0.8

    inputs_train, inputs_valid = train_test_split(
        inputs,
        random_state=random_seed,
        train_size=train_size,
        shuffle=True)
    targets_train, targets_valid = train_test_split(
        targets,
        random_state=random_seed,
        train_size=train_size,
        shuffle=True)

    dataset_train = SegmentationDataSet(inputs=inputs_train,
                                        targets=targets_train,
                                        transform=transforms_training,
                                        device=device)

    dataset_valid = SegmentationDataSet(inputs=inputs_valid,
                                        targets=targets_valid,
                                        transform=transforms_validation,
                                        device=device)

    dataloader_training = data.DataLoader(dataset=dataset_train,
                                          batch_size=15,
                                          shuffle=True,
                                          num_workers=4,
                                          pin_memory=True)

    dataloader_validation = data.DataLoader(dataset=dataset_valid,
                                            batch_size=15,
                                            shuffle=True,
                                            num_workers=4,
                                            pin_memory=True)


    model = torch.hub.load('pytorch/vision:v0.6.0', 'deeplabv3_resnet101', pretrained=False)


    criterion = torch.nn.CrossEntropyLoss()

    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.99)


    epochs = 10
    steps = 0
    running_loss = 0
    print_every = 10
    train_losses, valid_losses = [], []

    start_time = time.time()
    prev_time = time.time()


    for epoch in range(epochs):
        # Training
        for inputs, labels in dataloader_training:
            steps += 1
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            logps = model(inputs)
            loss = criterion(logps['out'], labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if steps % print_every == 0:
                # average over the print interval so the stored value matches what is printed
                train_losses.append(running_loss / print_every)
                epoch_time = time.time()
                elapsed_time = epoch_time - prev_time
                prev_time = epoch_time
                print(f"Epoch {epoch + 1}/{epochs}.. "
                      f"Train loss: {running_loss / print_every:.3f}.. "
                      f"Elapsed time: {elapsed_time}")

                running_loss = 0
                model.train()
        # Evaluation
        valid_loss = 0
        accuracy = 0
        model.eval()
        with torch.no_grad():
            for inputs, labels in dataloader_validation:
                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
                logps = model(inputs)
                batch_loss = criterion(logps['out'], labels)
                valid_loss += batch_loss.item()

                # the model returns raw logits, so take the argmax directly for the predicted class
                top_class = logps['out'].argmax(dim=1, keepdim=True)
                equals = top_class == labels.view(*top_class.shape)
                accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
        valid_losses.append(valid_loss / len(dataloader_validation))
        # average over the validation loader, not the training loader
        print(f"Epoch {epoch + 1}/{epochs}.. "
              f"Validation loss: {valid_loss / len(dataloader_validation):.3f}.. "
              f"Validation accuracy: {accuracy / len(dataloader_validation):.3f} ")
        model.train()
    torch.save(model, 'model.pth')

    end_time = time.time()
    total_time = end_time - start_time
    print("Total Time: ", total_time)
    plt.plot(train_losses, label='Training loss')
    plt.plot(valid_losses, label='Validation loss')
    plt.legend(frameon=False)
    plt.show()
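A quick look at one batch on each machine can also confirm both environments feed the model the same shapes, dtypes, and label values. A small check that reuses dataloader_training from the script above, so it would run right after the loaders are built:

# run on both the local machine and Colab, then compare the printouts
images, masks = next(iter(dataloader_training))
print(images.shape, images.dtype)   # expected: torch.Size([15, 3, 128, 128]) torch.float32
print(masks.shape, masks.dtype)     # expected: torch.Size([15, 128, 128]) torch.int64
print(torch.unique(masks))          # class indices present in this batch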

This is the output for one epoch on Colab:

Epoch 1/10.. Train loss: 2.080.. Elapsed time: 12.156640768051147
Epoch 1/10.. Train loss: 1.231.. Elapsed time: 8.76858925819397
Epoch 1/10.. Train loss: 1.051.. Elapsed time: 8.315532445907593
Epoch 1/10.. Train loss: 0.890.. Elapsed time: 8.249168634414673
Epoch 1/10.. Train loss: 0.839.. Elapsed time: 8.248667478561401
Epoch 1/10.. Train loss: 0.807.. Elapsed time: 8.120820999145508
Epoch 1/10.. Train loss: 0.742.. Elapsed time: 8.298616886138916
Epoch 1/10.. Train loss: 0.726.. Elapsed time: 8.170734167098999
Epoch 1/10.. Train loss: 0.677.. Elapsed time: 8.221246004104614
Epoch 1/10.. Train loss: 0.698.. Elapsed time: 8.124614000320435
Epoch 1/10.. Train loss: 0.675.. Elapsed time: 8.197462558746338
Epoch 1/10.. Train loss: 0.682.. Elapsed time: 8.263437509536743
Epoch 1/10.. Train loss: 0.626.. Elapsed time: 8.156179189682007
Epoch 1/10.. Train loss: 0.632.. Elapsed time: 8.268096446990967
Epoch 1/10.. Train loss: 0.616.. Elapsed time: 8.214547872543335
Epoch 1/10.. Train loss: 0.585.. Elapsed time: 8.31475019454956
Epoch 1/10.. Train loss: 0.598.. Elapsed time: 8.388074398040771
Epoch 1/10.. Train loss: 0.626.. Elapsed time: 8.179292440414429
Epoch 1/10.. Train loss: 0.612.. Elapsed time: 8.252359390258789
Epoch 1/10.. Train loss: 0.592.. Elapsed time: 8.284745693206787
Epoch 1/10.. Train loss: 0.597.. Elapsed time: 8.31213927268982
Epoch 1/10.. Train loss: 0.566.. Elapsed time: 8.164374113082886
Epoch 1/10.. Train loss: 0.556.. Elapsed time: 8.300082206726074
Epoch 1/10.. Train loss: 0.568.. Elapsed time: 8.26304841041565
Epoch 1/10.. Train loss: 0.572.. Elapsed time: 8.309881448745728
Epoch 1/10.. Train loss: 0.586.. Elapsed time: 8.211671352386475
Epoch 1/10.. Train loss: 0.586.. Elapsed time: 8.321797609329224
Epoch 1/10.. Train loss: 0.535.. Elapsed time: 8.318871021270752
Epoch 1/10.. Train loss: 0.543.. Elapsed time: 8.152915239334106
Epoch 1/10.. Train loss: 0.569.. Elapsed time: 8.251380205154419
Epoch 1/10.. Train loss: 0.526.. Elapsed time: 8.29153847694397
Epoch 1/10.. Train loss: 0.565.. Elapsed time: 8.15071702003479
Epoch 1/10.. Train loss: 0.542.. Elapsed time: 8.253364562988281
Epoch 1/10.. Validation loss: 0.182.. Validation accuracy: 0.271 

And here is the output on my local machine:

Epoch 1/10.. Train loss: 2.932.. Elapsed time: 32.148621797561646
Epoch 1/10.. Train loss: 1.852.. Elapsed time: 14.120505809783936
Epoch 1/10.. Train loss: 0.887.. Elapsed time: 14.210048198699951
Epoch 1/10.. Train loss: 0.618.. Elapsed time: 14.23294186592102
Epoch 1/10.. Train loss: 0.549.. Elapsed time: 14.212541103363037
Epoch 1/10.. Train loss: 0.519.. Elapsed time: 14.047481775283813
Epoch 1/10.. Train loss: 0.506.. Elapsed time: 14.060708045959473
Epoch 1/10.. Train loss: 0.347.. Elapsed time: 14.301624059677124
Epoch 1/10.. Train loss: 0.399.. Elapsed time: 13.9844491481781
Epoch 1/10.. Train loss: 0.361.. Elapsed time: 13.957871913909912
Epoch 1/10.. Train loss: 0.305.. Elapsed time: 14.164010763168335
Epoch 1/10.. Train loss: 0.296.. Elapsed time: 14.001536846160889
Epoch 1/10.. Train loss: 0.298.. Elapsed time: 14.019971132278442
Epoch 1/10.. Train loss: 0.271.. Elapsed time: 13.951345443725586
Epoch 1/10.. Train loss: 0.252.. Elapsed time: 14.037938594818115
Epoch 1/10.. Train loss: 0.283.. Elapsed time: 13.944657564163208
Epoch 1/10.. Train loss: 0.299.. Elapsed time: 13.977224826812744
Epoch 1/10.. Train loss: 0.219.. Elapsed time: 13.941975355148315
Epoch 1/10.. Train loss: 0.242.. Elapsed time: 13.936140060424805
Epoch 1/10.. Train loss: 0.244.. Elapsed time: 13.942122459411621
Epoch 1/10.. Train loss: 0.216.. Elapsed time: 13.960899114608765
Epoch 1/10.. Train loss: 0.186.. Elapsed time: 13.956881523132324
Epoch 1/10.. Train loss: 0.241.. Elapsed time: 13.944581985473633
Epoch 1/10.. Train loss: 0.203.. Elapsed time: 13.934357404708862
Epoch 1/10.. Train loss: 0.189.. Elapsed time: 13.938358306884766
Epoch 1/10.. Train loss: 0.181.. Elapsed time: 13.944468021392822
Epoch 1/10.. Train loss: 0.186.. Elapsed time: 13.946297407150269
Epoch 1/10.. Train loss: 0.164.. Elapsed time: 13.940366744995117
Epoch 1/10.. Train loss: 0.165.. Elapsed time: 13.938241720199585
Epoch 1/10.. Train loss: 0.176.. Elapsed time: 14.015569925308228
Epoch 1/10.. Train loss: 0.165.. Elapsed time: 14.019208669662476
Epoch 1/10.. Train loss: 0.175.. Elapsed time: 14.149503469467163
Epoch 1/10.. Train loss: 0.159.. Elapsed time: 14.128302097320557
Epoch 1/10.. Train loss: 0.155.. Elapsed time: 13.935027837753296
Epoch 1/10.. Train loss: 0.137.. Elapsed time: 13.937382221221924
Epoch 1/10.. Train loss: 0.127.. Elapsed time: 13.929635524749756
Epoch 1/10.. Train loss: 0.133.. Elapsed time: 13.935472011566162
Epoch 1/10.. Train loss: 0.152.. Elapsed time: 13.922808647155762
Epoch 1/10.. Validation loss: 0.032.. Validation accuracy: 0.239

I won’t paste more than this since it’s long and takes a while to run, but by the end of the 3rd epoch the loss on the Colab model is still bouncing around 0.5, whereas locally it reaches 0.02.

If anyone could help me resolve this issue it would be greatly appreciated.

Asked By: James Heaton


Answers:

I fixed this problem by unzipping the training data to Google Drive and reading the files from there, instead of using the Colab command to unzip the folder directly into my workspace. I have absolutely no idea why this was causing the problem; a quick visual inspection of the images and their corresponding tensors looks fine, but I can’t go through each of the 6,000 or so images to check every one.
If anyone knows why this was causing a problem, please let me know!
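For anyone wanting to rule out a corrupted or partial copy without opening 6,000 images by hand, something along these lines should work: fingerprint the dataset folder (the 'train' root from the question) by hashing file names and contents, run it in both environments, and compare the digests.

import hashlib
import pathlib

def dataset_fingerprint(root: pathlib.Path) -> str:
    """Hash every file name and its contents under root into one digest."""
    digest = hashlib.md5()
    for path in sorted(root.rglob('*')):
        if path.is_file():
            digest.update(path.name.encode())
            digest.update(path.read_bytes())
    return digest.hexdigest()

# Run in both environments; differing digests mean the copies differ.
print(dataset_fingerprint(pathlib.Path.cwd() / 'train'))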

Answered By: James Heaton