How to get only specific classes from PyTorch's FashionMNIST dataset?

Question:

The FashionMNIST dataset has 10 different output classes. How can I get a subset of this dataset with only specific classes? In my case, I only want images of sneaker, pullover, sandal and shirt classes (their classes are 7,2,5 and 6 respectively).

This is how I load my dataset.

train_dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())

The approach I’ve followed is below.
Iterate through the dataset one sample at a time and compare the second element of the returned tuple (index 1, the class label) to the classes I need. I'm stuck here: when the comparison is true, how can I add that sample to an initially empty dataset?

sneaker = 0
pullover = 0
sandal = 0
shirt = 0
for i in range(60000):
    if train_dataset_full[i][1] == 7:
        sneaker += 1
    elif train_dataset_full[i][1] == 2:
        pullover += 1
    elif train_dataset_full[i][1] == 5:
        sandal += 1
    elif train_dataset_full[i][1] == 6:
        shirt += 1

Now, in place of sneaker += 1, pullover += 1, sandal += 1 and shirt += 1, I want to do something like empty_dataset.append(train_dataset_full[i]).

If the above approach is incorrect, please suggest another method.

Asked By: Nurav Adnab

Answers:

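You can use a boolean mask to match the label. For example

idx = dataset.targets == 1
dataset.targets = dataset.targets[idx]
dataset.data = dataset.data[idx]  # filter the images with the same mask

That keeps only the samples with the label you want. Older torchvision versions exposed these attributes as train_labels and train_data; they have since been renamed to targets and data.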

Answered By: Minh-Long Luu

Finally found the answer.

dataset_full = torchvision.datasets.FashionMNIST(data_folder, train = True, download = True, transform = transforms.ToTensor())
# Selecting classes 7, 2, 5 and 6
idx = (dataset_full.targets==7) | (dataset_full.targets==2) | (dataset_full.targets==5) | (dataset_full.targets==6)
dataset_full.targets = dataset_full.targets[idx]
dataset_full.data = dataset_full.data[idx]

Answered By: Nurav Adnab
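
A variation on the answer above, as a minimal sketch: instead of overwriting targets and data, the wanted positions can be wrapped in torch.utils.data.Subset, which leaves the original dataset untouched. To stay runnable without a download, the example uses a synthetic TensorDataset as a stand-in for FashionMNIST; the names images, full_dataset, wanted and subset are illustrative and not from the original post. With the real dataset, the mask would be built from dataset_full.targets in the same way.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Synthetic stand-in for FashionMNIST (an assumption for this sketch):
# 100 blank 28x28 "images" with labels cycling through 0..9.
images = torch.zeros(100, 1, 28, 28)
targets = torch.arange(100) % 10
full_dataset = TensorDataset(images, targets)

# Classes to keep: sneaker (7), pullover (2), sandal (5), shirt (6).
wanted = torch.tensor([7, 2, 5, 6])

# Boolean mask over all samples, then the integer positions where it is True.
mask = torch.isin(targets, wanted)
indices = torch.nonzero(mask).flatten().tolist()

# Subset is a view onto full_dataset restricted to those positions.
subset = Subset(full_dataset, indices)
```

The subset can be passed to a DataLoader like any other dataset. Note that the labels keep their original values (2, 5, 6, 7), so a model with four outputs would still need them remapped to 0..3.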

I could not use dataset.train_labels or dataset.data, so I loaded the full dataset with a DataLoader, keeping all the labels, and then selected the labels I needed during the training step. In my case, the labels were 3 and 4. I am not sure whether this method is correct.

for epoch in range(2):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_dataloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # keep only the samples labelled 3 or 4; a plain `if` on the mask
        # raises an error for batches with more than one sample
        mask = (labels == 4) | (labels == 3)
        if not mask.any():
            continue
        inputs, labels = inputs[mask], labels[mask]

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

Answered By: Saltanat Khalyk
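
Filtering inside the training loop can be isolated into a small helper, sketched below under some assumptions. The functions filter_batch and remap_labels are hypothetical names, not part of PyTorch. The second helper addresses a point the answer leaves open: if the network has only as many outputs as there are kept classes, labels such as 3 and 4 probably need to be mapped to 0 and 1 before they reach a loss like CrossEntropyLoss.

```python
import torch

def filter_batch(inputs, labels, wanted):
    """Keep only the samples of a batch whose label appears in `wanted`."""
    mask = torch.isin(labels, wanted)
    return inputs[mask], labels[mask]

def remap_labels(labels, wanted):
    """Replace each label by its position in `wanted` (e.g. 3 -> 0, 4 -> 1).

    Assumes every label is present in `wanted`, which holds after filter_batch.
    """
    matches = labels.unsqueeze(1) == wanted.unsqueeze(0)  # shape (batch, len(wanted))
    return matches.long().argmax(dim=1)

# Toy batch: five one-feature samples with mixed labels.
inputs = torch.arange(5).float().unsqueeze(1)
labels = torch.tensor([0, 3, 4, 9, 3])
wanted = torch.tensor([3, 4])

kept_inputs, kept_labels = filter_batch(inputs, labels, wanted)
new_labels = remap_labels(kept_labels, wanted)
```

Filtering per batch means every epoch still loads and transforms the samples that get thrown away, so filtering the dataset once up front is usually cheaper when the set of classes is fixed.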

Here is a simple solution for an ImageFolder dataset: filter dataset.samples and update the class attributes to match.

## Dataset

import torch
from torchvision import datasets
import torchvision.transforms as transforms

PATH = "data/train"

transform = transforms.Compose([transforms.Resize(256),
                            transforms.RandomCrop(224),
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224,0.225])])

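dataset = datasets.ImageFolder(PATH, transform=transform)
# Keep only the folders named '3' and '8' and remap their labels to 0 and 1.
# Filtering on the original indices [0, 1] would keep the first two folders instead.
keep = ['3', '8']
old_to_new = {dataset.class_to_idx[c]: i for i, c in enumerate(keep)}
dataset.samples = [(path, old_to_new[t]) for path, t in dataset.samples if t in old_to_new]
dataset.targets = [t for _, t in dataset.samples]
dataset.classes = keep
dataset.class_to_idx = {c: i for i, c in enumerate(keep)}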

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

Here is a fully working example with ResNet18.

Answered By: Maciej Skorski

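To add to Nurav's answer: for datasets where dataset.targets is a Python list rather than a tensor (CIFAR10 in torchvision is one example), indexing the targets directly with idx does not work. Reindexing dataset.data still works because it is a NumPy array. Therefore, I had to reindex the targets as follows (taken from here). Note that idx must hold the integer positions of the samples to keep, not a boolean mask:

dataset_full.targets = [dataset_full.targets[index] for index in idx]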

Hope this helps anyone stuck on this.

Answered By: Adam
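
The list-based case can be sketched end to end as follows. This is an illustration under assumptions, not the code from the answer above: the arrays are synthetic stand-ins shaped like CIFAR10 data, and the names data, targets, wanted and idx are chosen for the example. The integer positions are built with a list comprehension and then used to filter both the NumPy array and the label list.

```python
import numpy as np

# Synthetic stand-ins (assumptions for this sketch): 20 blank 32x32 RGB
# images in a NumPy array, with labels stored in a plain Python list.
data = np.zeros((20, 32, 32, 3), dtype=np.uint8)
targets = [i % 10 for i in range(20)]

wanted = {3, 4}

# Integer positions of the samples to keep.
idx = [i for i, t in enumerate(targets) if t in wanted]

# A NumPy array accepts a list of positions directly; a Python list does not,
# so the labels are rebuilt element by element.
data = data[idx]
targets = [targets[i] for i in idx]
```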