How to remove some labels of a pytorch dataset
Question:
I have a torchvision.datasets object. I only want to keep some labels and delete the others.
For example, if my dataset is CIFAR10, loaded like this: trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
I will have 10 labels. I only want to keep the first three labels and delete the others. How can I do that?
P.S:
I think I can do that by building a dataset object from scratch like this, but I'm guessing there should be a shorter way:
class FilteredDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, desired_labels):
        self.dataset = dataset
        self.indices = [i for i, (_, target) in enumerate(self.dataset) if target in desired_labels]

    def __getitem__(self, index):
        return self.dataset[self.indices[index]]

    def __len__(self):
        return len(self.indices)
Answers:
Your approach is a good one. You can also define which labels you want to keep and create a new dataset object containing only those labels with torch.utils.data.Subset:
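import torch
import torchvision
desired_labels = [0, 1, 2]
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
# trainset.targets holds the labels, so no image is loaded while filtering
filtered_trainset = torch.utils.data.Subset(trainset, [i for i in range(len(trainset)) if trainset.targets[i] in desired_labels])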
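One thing to watch for: filtering keeps the original label values. That is harmless for [0, 1, 2], but if you keep something like [3, 5, 8], a loss such as cross-entropy over 3 classes would probably want them renumbered to 0..2. Below is a minimal sketch of one way to do that. The names indices_with_labels and RemappedLabels are hypothetical helpers made up for this example, not part of torch or torchvision, and the toy list stands in for the real dataset so the snippet runs without a download.

```python
def indices_with_labels(targets, desired_labels):
    """Return the positions in `targets` whose label is in `desired_labels`."""
    wanted = set(desired_labels)  # set lookup instead of scanning a list each time
    # int(t) also covers datasets whose targets are tensor or numpy scalars
    return [i for i, t in enumerate(targets) if int(t) in wanted]


class RemappedLabels:
    """Hypothetical wrapper: exposes only `indices` of `dataset` and
    renumbers the kept labels to 0..k-1 in the order of `desired_labels`."""

    def __init__(self, dataset, indices, desired_labels):
        self.dataset = dataset
        self.indices = list(indices)
        self.label_map = {label: new for new, label in enumerate(desired_labels)}

    def __getitem__(self, index):
        sample, target = self.dataset[self.indices[index]]
        return sample, self.label_map[int(target)]

    def __len__(self):
        return len(self.indices)


# Stand-in data (assumption): (sample, label) pairs instead of real CIFAR10 images.
toy_targets = [0, 3, 5, 3, 8, 1, 5]
toy_dataset = [(f"img{i}", t) for i, t in enumerate(toy_targets)]

kept = indices_with_labels(toy_targets, [3, 5, 8])
remapped = RemappedLabels(toy_dataset, kept, [3, 5, 8])
```

With the real data, the same index list from indices_with_labels(trainset.targets, desired_labels) could be passed straight to torch.utils.data.Subset if no renumbering is needed. Either way the filter reads only the stored labels, whereas the enumerate(self.dataset) loop in the question loads every image once just to look at its label.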