How to calculate mean and standard deviation of each channel of cifar10 dataset with pytorch?

Question:

I am trying to calculate the mean and standard deviation of each channel of the CIFAR-10 dataset. I tried this code:

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor()])

cifar10_train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

cifar10_test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# creating concatenated train and test cifar10 dataset
cifar10_dataset = torch.utils.data.ConcatDataset([cifar10_train_dataset, cifar10_test_dataset])

mean = 0.
std = 0.
for images, labels in cifar10_dataset:
    for image in images:
        mean += image.mean(axis=(1, 2))
        std += image.std(axis=(1, 2))
mean /= len(cifar10_dataset) * 3
std /= len(cifar10_dataset) * 3

but I get this error:

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

How can I do it?

Asked By: Ir8_mind

||

Answers:

The utility that you are using to merge the two datasets:

torch.utils.data.ConcatDataset()

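is intended to feed its result into a DataLoader, hence it does not expose the underlying data for direct manipulation (it has no .data attribute). Iterating over it directly yields one (image, label) pair per sample, so in your loop images is a single [3, 32, 32] tensor, and for image in images walks over its three [32, 32] channels. A 2-D tensor has no dimension 2, which is exactly what the IndexError reports.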
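If you would rather keep the ConcatDataset, one option is to stream it through a DataLoader and accumulate per-channel sums and sums of squares. The following is a minimal sketch, assuming every sample is an (image, label) pair whose image is a [C, H, W] float tensor as produced by ToTensor(); channel_mean_std is a hypothetical helper name, not a torchvision function.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def channel_mean_std(dataset: Dataset, batch_size: int = 1000):
    """Hypothetical helper: per-channel mean and population std of a dataset
    whose samples are (image, label) pairs with image shaped [C, H, W]."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    n_pixels = 0      # pixels seen per channel
    total = 0.0       # running per-channel sum (a tensor after the first batch)
    total_sq = 0.0    # running per-channel sum of squares
    for images, _ in loader:                 # images: [B, C, H, W]
        images = images.double()             # float64 limits rounding error over many pixels
        total = total + images.sum(dim=(0, 2, 3))
        total_sq = total_sq + (images ** 2).sum(dim=(0, 2, 3))
        n_pixels += images.shape[0] * images.shape[2] * images.shape[3]
    mean = total / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding error.
    std = (total_sq / n_pixels - mean ** 2).clamp(min=0).sqrt()
    return mean, std
```

With the datasets from the question, mean, std = channel_mean_std(cifar10_dataset) would return values on the [0, 1] scale. This is the standard deviation over all pixels of a channel, the usual choice for transforms.Normalize; it is not in general equal to the average of per-image standard deviations, which is what the original loop was attempting.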

If you have to compute first-order statistics, you are better off joining the .data attributes of the two datasets manually; each one is a NumPy array of the raw uint8 pixels:

import numpy as np

# .data has shape (N, 32, 32, 3); averaging over axes (0, 1, 2) leaves
# one mean per RGB channel, on the 0-255 pixel scale.
print(f"Train dataset mean: {cifar10_train_dataset.data.mean(axis=(0,1,2))}")
print(f"Test dataset mean: {cifar10_test_dataset.data.mean(axis=(0,1,2))}")
print(f"Merge dataset mean: {np.vstack((cifar10_train_dataset.data, cifar10_test_dataset.data)).mean(axis=(0,1,2))}")

Which gives us:

Train dataset mean: [125.30691805 122.95039414 113.86538318]
Test dataset mean: [126.02464141 123.7085042  114.85431865]
Merge dataset mean: [125.42653861 123.07674582 114.03020576]

This is also much faster, since the whole computation is vectorized by NumPy instead of looping over samples in Python.
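The snippet above reports only the mean, and in raw pixel units, whereas the question applies ToTensor(), which rescales pixels to [0, 1]. A minimal sketch that also returns the per-channel standard deviation on that [0, 1] scale is shown here; numpy_channel_stats is a hypothetical helper, and its inputs are assumed to be channels-last uint8 arrays like the .data attributes used above.

```python
import numpy as np


def numpy_channel_stats(*arrays):
    """Hypothetical helper: per-channel mean and std of uint8 image arrays
    shaped (N, H, W, C), rescaled to [0, 1] to match ToTensor()."""
    # Join the datasets along the sample axis (same as np.vstack for these arrays).
    data = np.concatenate(arrays, axis=0)
    # float64 is already NumPy's default for integer input; spelled out for clarity.
    mean = data.mean(axis=(0, 1, 2), dtype=np.float64) / 255.0
    # Population std (ddof=0) over every pixel of each channel.
    std = data.std(axis=(0, 1, 2), dtype=np.float64) / 255.0
    return mean, std
```

Called as numpy_channel_stats(cifar10_train_dataset.data, cifar10_test_dataset.data), the resulting values can be passed to transforms.Normalize(mean.tolist(), std.tolist()) after ToTensor(). As a sanity check on the scale, dividing the train means printed above by 255 gives roughly 0.4914, 0.4822 and 0.4465. NumPy's std internally allocates a float64 array the size of the input (roughly 1.5 GB for all 60,000 images), so processing one channel at a time may be preferable on a low-memory machine.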

Answered By: Neervana