Why does my PyTorch Dataset class not return a list?


I am trying to use torch.utils.data.Dataset on a custom dataset. In my dataset, each row contains a list of 10 images, as follows:

| word | images | gold_image |
|------|--------|------------|

With batch_size=4, I expect the DataLoader to return batches like this:

('word_1', 'word_2', 'word_3', 'word_4'), ([image_1,image_2,image_3],[image_4,image_5,image_6],[image_7,image_8,image_9], [image_10,image_11,image_12]), ([0,0,1],[1,0,0],[0,1,0],[0,1,0])

But instead I get this:

('word_1', 'word_2', 'word_3', 'word_4'), [(image_1,image_2,image_3,image_4),(image_5,image_6,image_7,image_8), (image_9,image_10,image_11,image_12)], [(0,1,0,0),(1,0,0,0),(0,1,0,1)]

Here is my code:

class ImageTextDataset(Dataset):
    """Dataset returning ``(context, images, labels)`` for one row of ``train_df``.

    ``images`` is a list of transformed image tensors, one per file name in the
    row's ``images`` column (files are read from ``<data_dir>/trial_images_v1``).
    ``labels`` is a parallel list holding ``1.0`` for the row's gold image and
    ``0.0`` for every other image.

    Note: with the default ``DataLoader`` collate function these per-sample
    lists are transposed across the batch; pass a custom ``collate_fn`` to keep
    one list per sample.
    """

    def __init__(self, data_dir, train_df, tokenizer, feature_extractor, data_type, device, text_augmentation=False):
        """Read the columns of ``train_df`` into per-row lists.

        Args:
            data_dir: root directory containing the ``trial_images_v1`` folder.
            train_df: table indexed by column name; the ``word``, ``images``,
                ``description`` and ``gold_image`` columns are read.
            tokenizer: stored on the instance; not used by this class itself.
            feature_extractor: accepted for interface compatibility; unused.
            data_type: split to load; only ``"train"`` is supported.
            device: accepted for interface compatibility; unused.
            text_augmentation: accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``data_type`` is not ``"train"``, because no data
                would be loaded and every later access would fail.
        """
        if data_type != "train":
            raise ValueError(f"unsupported data_type {data_type!r}; only 'train' is supported")
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        # Resize every image to 512x512, convert it to a tensor and normalize
        # each of the three channels with mean 0.5 / std 0.5.
        self.transforms = transforms.Compose([
            transforms.Resize([512, 512]),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.all_image_names = list(train_df['images'])
        self.keywords = list(train_df['word'])
        self.context = list(train_df['description'])
        self.gold_images = list(train_df['gold_image'])

    def __len__(self):
        """Return the number of rows (one sample per row)."""
        return len(self.context)

    def __getitem__(self, idx):
        """Load the images of row ``idx`` and label the gold one.

        Returns:
            ``(context, images, labels)``: the row's description, the list of
            transformed image tensors, and a parallel list of float labels.
        """
        context = self.context[idx]
        # NOTE(review): assumes gold_image holds one of the file names listed in
        # the row's `images` column -- confirm against the data.
        gold_image = self.gold_images[idx]
        images = []
        labels = []
        for name in self.all_image_names[idx]:
            path = os.path.join(self.data_dir, "trial_images_v1", name)
            # The context manager closes the file handle that Image.open keeps open.
            with Image.open(path) as img:
                if img.mode != "RGB":
                    img = img.convert('RGB')
                images.append(self.transforms(img))
            # Compare the file name, not the transformed tensor, against the gold image.
            labels.append(1.0 if name == gold_image else 0.0)
        return (context, images, labels)

I can’t figure out what the issue is.
Can anyone help?


Asked By: Shantanu Nath



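The DataLoader collates the output of your dataset into batches using the default collate function (implemented in torch/utils/data/_utils/collate.py). What you’re observing is the expected behavior when a dataset returns sequence-type objects (like lists). For a sequence, default_collate transposes the batch: it groups the i-th element of every sample together and collates each group. So with batch_size=4 and three images per sample, you get three groups of four images rather than four lists of three images, and the labels are regrouped the same way.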

If you want the dataloader to collate your data differently, then you can provide a custom collate function via the collate_fn argument of DataLoader.

You can read more about collation and custom collate functions in the documentation.

The following is a simple example of using a custom collate function that I believe accomplishes what you want, though you may need to change it a bit if it’s not exactly what you need.

import torch
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

class FakeDataset:
    """Tiny synthetic dataset used to demonstrate custom collation.

    Sample ``index`` is ``('context_<index>', images, label)``, where:

    * ``images`` is a list of three float tensors of shape (2, 5, 5), each
      filled with the value ``index``;
    * ``label`` is a one-hot list of length 3 whose hot position is ``index % 3``.

    The dataset always reports a length of 100.
    """

    def __len__(self):
        return 100

    def __getitem__(self, index):
        hot_position = index % 3
        one_hot = [1.0 if slot == hot_position else 0.0 for slot in range(3)]
        tensors = [torch.full((2, 5, 5), index, dtype=torch.float) for _ in range(3)]
        return f'context_{index}', tensors, one_hot

def my_collate_fn(batch):
    """Collate ``(context, images, label)`` samples, keeping one group per sample.

    Args:
        batch: list of samples, each a tuple as returned by
            ``FakeDataset.__getitem__``.

    Returns:
        ``(contexts, images, labels)``, where ``contexts`` is the list of
        context strings, ``images`` holds one tensor per sample (that sample's
        image list stacked by ``default_collate``), and ``labels`` holds one
        tensor per sample.
    """
    raw_contexts = []
    per_sample_images = []
    per_sample_labels = []
    for sample in batch:
        raw_contexts.append(sample[0])
        # Collate within a sample only, so the sample's images stay together.
        per_sample_images.append(default_collate(sample[1]))
        per_sample_labels.append(default_collate(sample[2]))
    # default_collate hands a list of strings back unchanged.
    return default_collate(raw_contexts), per_sample_images, per_sample_labels

# Batch FakeDataset samples four at a time using the custom collate function.
loader = DataLoader(FakeDataset(), batch_size=4, collate_fn=my_collate_fn)

# Pull one batch to show the shape of what the custom collation produces.
first_batch = next(iter(loader))
contexts, images, labels = first_batch

image_shapes = [tensor.shape for tensor in images]
print('contexts =', contexts)
print('images (sizes) =', image_shapes)
print('labels =', labels)

which prints

contexts = ['context_0', 'context_1', 'context_2', 'context_3']
images (sizes) = [torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5]), torch.Size([3, 2, 5, 5])]
labels = [tensor([1., 0., 0.], dtype=torch.float64), tensor([0., 1., 0.], dtype=torch.float64), tensor([0., 0., 1.], dtype=torch.float64), tensor([1., 0., 0.], dtype=torch.float64)]

Note that we make use of PyTorch’s default_collate function to avoid having to rewrite that logic.

Answered By: jodag