How to remove some labels from a PyTorch dataset

Question:

I have a torchvision.datasets object. I only want to keep some labels and delete the others.

For example, if my dataset is CIFAR10, loaded with trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True), it has 10 labels. I only want to keep the first three labels and delete the others. How can I do that?

P.S:
I think I can do that by building a dataset object from scratch like this. But I’m guessing there should be a shorter way to do that:

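import torch

class FilteredDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, desired_labels):
        self.dataset = dataset
        # Note: enumerating the dataset loads every sample once just to read its label
        self.indices = [i for i, (_, target) in enumerate(self.dataset) if target in desired_labels]

    def __getitem__(self, index):
        # Translate the filtered position back to a position in the wrapped dataset
        return self.dataset[self.indices[index]]

    def __len__(self):
        return len(self.indices)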
Asked By: Peyman

||

Answers:

Your approach is a good one. You can also define which labels you want to keep and create a new dataset object containing only those labels with torch.utils.data.Subset:

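import torch
from torchvision import datasets

desired_labels = [0, 1, 2]
trainset = datasets.CIFAR10(root='./data', train=True, download=True)
# Read labels from trainset.targets so no image has to be loaded while filtering
indices = [i for i, target in enumerate(trainset.targets) if target in desired_labels]
filtered_trainset = torch.utils.data.Subset(trainset, indices)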
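
If you want to check the index selection without downloading CIFAR10, here is a minimal sketch of the same filtering step in plain Python. The helper name select_label_indices and the toy label list are hypothetical, made up for illustration; the list stands in for trainset.targets. It also builds an old-to-new label map, on the assumption that you may later keep labels that are not already 0, 1, 2.

```python
from typing import Dict, Iterable, List, Sequence, Tuple


def select_label_indices(
    targets: Sequence[int], desired_labels: Iterable[int]
) -> Tuple[List[int], Dict[int, int]]:
    """Return the positions whose label is kept, plus an old-to-new label map."""
    kept = sorted(set(desired_labels))
    # Map each kept label to a contiguous id (0..k-1); this only matters
    # when the kept labels are not already 0, 1, 2, ...
    label_map = {old: new for new, old in enumerate(kept)}
    # Dict membership is a constant-time lookup per sample.
    indices = [i for i, t in enumerate(targets) if int(t) in label_map]
    return indices, label_map


# Toy labels standing in for trainset.targets (hypothetical data).
toy_targets = [0, 3, 1, 9, 2, 2, 7, 0]
indices, label_map = select_label_indices(toy_targets, [0, 1, 2])
print(indices)    # [0, 2, 4, 5, 7]
print(label_map)  # {0: 0, 1: 1, 2: 2}
```

The resulting indices list is what you would pass as the second argument to torch.utils.data.Subset. Keep in mind that Subset only selects samples and does not rewrite labels, so if you keep something like [3, 5, 7] you would probably want to apply label_map to the targets yourself before training a 3-class classifier.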
Answered By: Phoenix