'DataLoader' object does not support indexing
Question:
I have downloaded the ImageNet dataset via the torchvision API by setting download=True, but I cannot iterate through the DataLoader.
The error says “‘DataLoader’ object does not support indexing”
trainset = torch.utils.data.DataLoader(
datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
download=False))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)
As a simple test, I tried to run the following:
trainloader[0]
In the root directory, the pattern is:
root/
    train/
        n01440764/
        n01443537/
            n01443537_2.jpg
The docs on the official website don't say anything else: https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet
What am I doing wrong?
Answers:
The dataset passed to torch.utils.data.DataLoader() should be a torch.utils.data.Dataset, not a torch.utils.data.DataLoader, which is what you are passing in the code above.
So, your code should be:
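import torch
import torchvision

# torchvision can no longer download ImageNet, so the download flag is dropped;
# the ImageNet archives must already be in the root directory
trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
                                         split='train')
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=1,
                                          shuffle=False,
                                          num_workers=1)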
For more details, see the official PyTorch documentation for torch.utils.data.
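To see where the error message comes from, here is a minimal sketch that uses a small TensorDataset as a stand-in for ImageNet (an assumption, so it runs without the real data). When a DataLoader is wrapped in another DataLoader, the outer loader tries to index the inner one as if it were a Dataset, which raises a TypeError. The exact wording ("does not support indexing" vs "is not subscriptable") depends on the Python version.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for ImageNet: 6 samples of (feature, label)
dataset = TensorDataset(torch.arange(6, dtype=torch.float32), torch.arange(6))


def first_batch(loader):
    """Return the first batch, or the TypeError raised while fetching it."""
    try:
        return next(iter(loader))
    except TypeError as err:
        return err


# Wrong: the outer DataLoader receives a DataLoader and calls inner[idx]
wrong = DataLoader(DataLoader(dataset), batch_size=2)
# Right: the DataLoader receives the Dataset itself
right = DataLoader(dataset, batch_size=2)

print(repr(first_batch(wrong)))  # a TypeError about indexing the inner DataLoader
features, labels = first_batch(right)
print(features, labels)  # tensor([0., 1.]) tensor([0, 1])
```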
Well, the answer is pretty simple (besides the error mentioned in the other answer).
DataLoader has no __getitem__ method (see the source code for yourself).
It is meant for iterating over data (or batches of data), not for random access. If you want to access a specific element, you should use the torch.utils.data.Dataset directly; in your case:
trainset = torchvision.datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train')
trainset[0]  # a single (image, class_index) sample
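A minimal sketch of the difference, again assuming a small TensorDataset in place of ImageNet: the Dataset answers dataset[i] directly, while the DataLoader built on top of it only supports iteration (and len()).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 10 samples, label = parity of the feature
dataset = TensorDataset(torch.arange(10, dtype=torch.float32), torch.arange(10) % 2)
loader = DataLoader(dataset, batch_size=4)

# Random access works on the Dataset: returns one (feature, label) sample
feature, label = dataset[3]
print(feature.item(), label.item())  # 3.0 1

# The DataLoader defines __iter__ and __len__, but no __getitem__
try:
    loader[0]
    loader_supports_indexing = True
except TypeError:
    loader_supports_indexing = False
print(loader_supports_indexing)  # False
print(len(loader))  # 3 batches: sizes 4, 4 and 2
```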
Getting a batch
If you want to get a batch, you can iterate over the loader and break after the first iteration:
for batch in trainloader:
    print(batch)  # or anything else you want to do
    break
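If you only need one batch, next(iter(loader)) is a common shorthand for the loop-and-break pattern. A minimal sketch, with a small TensorDataset standing in for ImageNet; note that each call to iter() starts a new pass, so with shuffle=True the "first" batch can differ from call to call.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10, dtype=torch.float32), torch.arange(10))
trainloader = DataLoader(dataset, batch_size=4, shuffle=False)

# Loop-and-break version
for batch in trainloader:
    first_via_loop = batch
    break

# Shorthand: create an iterator and take its first element
first_via_next = next(iter(trainloader))

x, y = first_via_next
print(x.shape, y)  # torch.Size([4]) tensor([0, 1, 2, 3])
```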
DataLoader draws its indices from a sampler (sequential by default, random with shuffle=True, or a custom one you pass in; see samplers), hence there is no __getitem__, as it wouldn't make sense for this object.
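To illustrate the sampler's role, this sketch lists the index batches a DataLoader would request, using SequentialSampler, RandomSampler and BatchSampler from torch.utils.data. The seeded generator is only there to make the shuffled order reproducible; the specific permutation is not meaningful.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler

data = list(range(10))  # any sized object works as a sampler's data source

# What shuffle=False uses: indices in order, grouped into batches of 4
sequential = list(BatchSampler(SequentialSampler(data), batch_size=4, drop_last=False))
print(sequential)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# What shuffle=True uses: a fresh random permutation on every pass
gen = torch.Generator().manual_seed(0)
shuffled = list(BatchSampler(RandomSampler(data, generator=gen), batch_size=4, drop_last=False))
print(shuffled)  # the same indices, in a random order

# A DataLoader exposes the batch sampler it built from these arguments
loader = DataLoader(data, batch_size=4, shuffle=False)
print(list(loader.batch_sampler) == sequential)  # True
```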
You could also subclass DataLoader and write your own __getitem__ that does what you want (more complicated, though).
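A minimal sketch of that idea, assuming a map-style dataset and automatic batching (a batch_size is set). IndexableDataLoader is a hypothetical name, not part of PyTorch. It walks the loader's batch_sampler to find the indices of batch idx, then loads and collates only those samples, so earlier batches are never loaded. With shuffle=True the sampler reshuffles on every pass, so the same idx would return different batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class IndexableDataLoader(DataLoader):
    """Hypothetical DataLoader subclass that adds batch-level indexing."""

    def __getitem__(self, idx):
        if idx < 0:
            idx += len(self)  # allow negative indices, like a list
        # Walk the index batches only; no samples are loaded here
        for i, indices in enumerate(self.batch_sampler):
            if i == idx:
                # Load just the samples of this batch and collate them
                return self.collate_fn([self.dataset[j] for j in indices])
        raise IndexError(f"batch index {idx} out of range")


dataset = TensorDataset(torch.arange(10))
loader = IndexableDataLoader(dataset, batch_size=3, shuffle=False)
print(loader[1][0])   # tensor([3, 4, 5])
print(loader[-1][0])  # tensor([9])
```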
Full example
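import torch
from torchvision import datasets, transforms

# torch.utils.data.Dataset object
# ToTensor() is needed: the default collate function cannot batch PIL images
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/', split='train',
                             transform=transforms.ToTensor())
# torch.utils.data.DataLoader object
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False)
for batch in trainloader:
    print(batch)
    break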
The above should print the first batch, whatever it contains.
Solution
import torch
from torchvision import datasets, transforms

input_transform = transforms.Compose([
    transforms.Resize((255, 255)),  # resize every image to 255x255,
    transforms.CenterCrop(224),     # then crop the center so all images are 224x224
    transforms.ToTensor()
])
# torch.utils.data.Dataset object
trainset = datasets.ImageNet('/media/farshid/DataStore/temp/Imagenet/',
                             split='train', transform=input_transform)
# torch.utils.data.DataLoader object
trainloader = torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=False)
for batch_idx, data in enumerate(trainloader, 0):
    x, y = data  # x: images [2, 3, 224, 224], y: class indices [2]
    break
I ended up with this dirty solution:
def Dataloader_by_Index(data_loader, target=0):
    # note: this loads every batch before `target` just to skip it
    for index, data in enumerate(data_loader):
        if index == target:
            return data
    return None

fifth_element = Dataloader_by_Index(my_data_loader, target=4)
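A slightly more compact version of the same idea uses itertools.islice to skip ahead. Like the loop above, it still loads (and discards) every batch before the target, so it is only practical for small indices. The helper name batch_at and the TensorDataset stand-in are illustrative assumptions.

```python
from itertools import islice

import torch
from torch.utils.data import DataLoader, TensorDataset


def batch_at(data_loader, target=0):
    """Return batch number `target` (0-based), or None if the loader is shorter."""
    return next(islice(data_loader, target, None), None)


dataset = TensorDataset(torch.arange(10))
my_data_loader = DataLoader(dataset, batch_size=2, shuffle=False)
fifth = batch_at(my_data_loader, target=4)
print(fifth[0])  # tensor([8, 9])
print(batch_at(my_data_loader, target=5))  # None
```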