Can't use plt to show an image from the CIFAR-10 dataset in Google Colab
Question:
I’m trying to show images from the CIFAR-10 dataset, but for some reason plt shows me an AxesImage instead of the actual image that I want to see.
from cs231n.data_utils import load_CIFAR10
import matplotlib.pyplot as plt
Xtr, Ytr, Xte, Yte = load_CIFAR10('cs231n/datasets/cifar-10-batches-py')
# Xtr[0].shape is (32, 32, 3), i.e. an RGB image.
plt.imshow(Xtr[0])
The docs say that an array of shape (M, N, 3) is fine for RGB images, so I don’t know why it isn’t displayed. Any ideas?
Answers:
As discussed in the comments, the issue is the data type of the values in Xtr[0]. For an (M, N, 3) array such as Xtr[0], imshow expects either a float array with values between 0 and 1, or an integer array with values between 0 and 255. Here Xtr[0] is a float array with values between 0 and 255, which fits neither case: imshow clips the floats to [0, 1] (and logs a warning), so almost every pixel saturates and the picture is unrecognizable. The simplest solution is to pass the array to imshow as integers, e.g.,
plt.imshow(Xtr[0].astype(int))
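A minimal sketch of the two conversions that imshow accepts, assuming NumPy is available. Since the CIFAR-10 loader is not reproduced here, a random float array with values in [0, 255] stands in for Xtr[0]. The helper names to_uint8_image and to_unit_float_image are hypothetical, not part of matplotlib or cs231n.

```python
import numpy as np


def to_uint8_image(img):
    """Convert an (M, N, 3) array holding values in [0, 255] to uint8.

    Values are rounded and clipped before casting, so small float errors
    (e.g. 255.4 or -0.3) cannot fall outside the uint8 range.
    """
    arr = np.asarray(img)
    if arr.ndim != 3 or arr.shape[-1] != 3:
        raise ValueError(f"expected an (M, N, 3) array, got shape {arr.shape}")
    return np.clip(np.rint(arr), 0, 255).astype(np.uint8)


def to_unit_float_image(img):
    """Rescale an (M, N, 3) array with values in [0, 255] to floats in [0, 1]."""
    arr = np.asarray(img, dtype=np.float64)
    return np.clip(arr / 255.0, 0.0, 1.0)


# Hypothetical stand-in for Xtr[0]: a 32x32 RGB "image" stored as float64
# with integer values in [0, 255], the case imshow clips.
rng = np.random.default_rng(0)
fake_sample = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float64)

# Quick diagnostic: dtype and value range tell you which rule imshow applies.
print(fake_sample.dtype, fake_sample.min(), fake_sample.max())

as_uint8 = to_uint8_image(fake_sample)      # integer path: values 0..255
as_float = to_unit_float_image(fake_sample)  # float path: values 0..1
```

In the notebook, either result could then be passed to imshow, e.g. plt.imshow(to_uint8_image(Xtr[0])) or plt.imshow(Xtr[0] / 255.0). uint8 is the conventional dtype for 8-bit pixel data, while the [0, 1] float form is convenient if the image is going to be fed into further numeric processing.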