Incorrect format of labels in custom dataset for multi-label classification

Question:

I’m trying to implement a custom Dataset for multi-label classification. That is, one element may have multiple classes simultaneously. I tried returning the one-hot encoded representation or the class indices from the dataset directly, but neither of them works.

  • One-hot encoded produces RuntimeError: expected scalar type Long but found Float when calculating the loss.
  • Returning labels produces IndexError: Target 1 is out of bounds. when calculating the loss.

Here’s a dummy implementation:

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

n_classes = 3


class Data(Dataset):
    def __getitem__(self, index):
        return torch.tensor([[0, 0, 0, 0.0]]), torch.tensor([0.0] * n_classes)
        # return torch.tensor([[0, 0, 0, 0.0]]), torch.tensor(range(n_classes))

    def __len__(self):
        return 10


data_loader = DataLoader(Data())
model = nn.Sequential(nn.Linear(4, n_classes), nn.ReLU())

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

model.train()
for epoch in range(3):
    print('EPOCH {}:'.format(epoch))
    for inputs, labels in data_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

I couldn’t really find any documentation or tutorials for implementing this sort of dataset, and the base Dataset docs are quite terse as well. Am I missing something, or doing it all wrong?

Asked By: Felix

Answers:

As the error says, the labels must be of type long. Convert them when computing the loss:

loss = loss_fn(outputs, labels.long())
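
To illustrate the cast in isolation, here is a minimal sketch. It assumes logits of shape (N, C) and labels that hold one class index per sample but arrive as floats; the tensor values and the names `logits` and `batch_size` are made up for the example and do not come from the question.

```python
import torch
from torch import nn

torch.manual_seed(0)

n_classes = 3
batch_size = 4

# Logits: one row per sample, one column per class -> shape (N, C).
logits = torch.randn(batch_size, n_classes)

# Hypothetical labels that hold class indices but are stored as floats.
labels = torch.tensor([0.0, 2.0, 1.0, 2.0])

loss_fn = nn.CrossEntropyLoss()
# Cast to long (int64) so the targets are treated as class indices.
loss = loss_fn(logits, labels.long())
```

Casting in `__getitem__` instead (for example, returning `torch.tensor(2)`, which is already int64) would likely make the cast in the training loop unnecessary.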

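Keep in mind that one-hot encoded targets are not the right input for nn.CrossEntropyLoss here: its target is expected to contain class indices, that is, one integer in the range [0, number of classes) per sample. A label outside that range is what triggers the "Target ... is out of bounds" error.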
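
Since the question is about multi-label targets, where class indices alone cannot express "several classes at once", one commonly used alternative is nn.BCEWithLogitsLoss with float multi-hot targets. This is not part of the answer above; the following is only a sketch under assumptions: the dataset name `MultiLabelData`, the dummy feature and target values, the batch size, and the 0.5 threshold are all hypothetical.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

n_classes = 3


class MultiLabelData(Dataset):
    """Hypothetical dataset: 4 features and a multi-hot float target per sample."""

    def __getitem__(self, index):
        features = torch.zeros(4)               # shape (4,), not (1, 4)
        target = torch.tensor([1.0, 0.0, 1.0])  # classes 0 and 2 are both active
        return features, target

    def __len__(self):
        return 10


data_loader = DataLoader(MultiLabelData(), batch_size=5)
model = nn.Linear(4, n_classes)     # outputs raw logits, no final activation
loss_fn = nn.BCEWithLogitsLoss()    # applies a sigmoid per class internally
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

model.train()
for inputs, targets in data_loader:  # inputs: (5, 4), targets: (5, 3)
    optimizer.zero_grad()
    outputs = model(inputs)          # logits of shape (5, 3)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()

# Thresholding the sigmoid outputs (0.5 is an arbitrary choice) yields label sets.
predicted = torch.sigmoid(model(inputs)) > 0.5
```

In this setup the targets stay float and have the same shape as the model output, so no cast to long is involved. Each class is scored independently, which is what allows one sample to carry several labels.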

Answered By: A.Mounir