How to convert one-hot vector to label index and back in PyTorch?

Question:

How to transform vectors of labels to one-hot encoding and back in PyTorch?

I copied the solution here because I had to read through an entire forum discussion to find it, instead of finding a simple answer by googling.

Asked By: Gulzar


Answers:

From the PyTorch forums:

import torch
import numpy as np


labels = torch.randint(0, 10, (10,))

# labels --> one-hot 
one_hot = torch.nn.functional.one_hot(labels)
# one-hot --> labels
labels_again = torch.argmax(one_hot, dim=1)

np.testing.assert_equal(labels.numpy(), labels_again.numpy())
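The same round trip works for label tensors with extra dimensions, because `one_hot` appends the class axis last, so taking `argmax` over `dim=-1` inverts it. A minimal sketch, assuming a hypothetical `(batch, seq_len)` label tensor and 10 classes. The `.float()` cast is only needed if the one-hot tensor feeds a loss or a matrix multiply, since `one_hot` returns an int64 tensor.

```python
import torch
import torch.nn.functional as F

num_classes = 10  # assumption: every label lies in [0, num_classes)
# hypothetical batch of 4 sequences with 6 labels each
labels = torch.randint(0, num_classes, (4, 6))

# labels --> one-hot: the class axis is appended last, giving shape (4, 6, 10)
one_hot = F.one_hot(labels, num_classes=num_classes)

# one-hot --> labels: argmax over that last axis gives back shape (4, 6)
labels_again = one_hot.argmax(dim=-1)

# one_hot returns int64; cast when a float tensor is needed downstream
one_hot_float = one_hot.float()

assert one_hot.shape == (4, 6, num_classes)
assert torch.equal(labels, labels_again)
```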
Answered By: Gulzar

Since I can’t comment on the accepted answer, I just wanted to add that if your target does not include all classes (e.g. because you train in batches), you can specify the number of classes as an argument:

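import torch
target = torch.tensor([0, 2, 3])  # example batch: classes 4-6 never appear
# labels --> one-hot (always 7 columns, even when some classes are missing)
one_hot = torch.nn.functional.one_hot(target, num_classes=7)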
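To see why `num_classes` matters, compare the shapes with and without it on a batch that happens to miss the highest classes. Without the argument, `one_hot` infers the width as one greater than the largest label in the input, so two batches can yield tensors of different widths. A small sketch, using a made-up `target` batch and the 7 classes from the answer:

```python
import torch
import torch.nn.functional as F

# hypothetical batch that only contains classes 0, 2 and 3 out of 7
target = torch.tensor([0, 2, 3, 2])

inferred = F.one_hot(target)              # width inferred as max(target) + 1 = 4
fixed = F.one_hot(target, num_classes=7)  # width is always 7

print(inferred.shape)  # torch.Size([4, 4])
print(fixed.shape)     # torch.Size([4, 7])

# argmax still recovers the original labels in both cases
assert torch.equal(inferred.argmax(dim=1), target)
assert torch.equal(fixed.argmax(dim=1), target)
```

Fixing the width keeps every batch's one-hot tensor the same shape, which is what a model's output layer or a loss function usually expects.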
Answered By: swageta