Is the DataLoader object an iterable object?

Question:

This is the code

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math

# creating a custom class for our dataset, which inherits from Dataset.
class WineDataset(Dataset):

    # this function is used for data loading
    def __init__(self):
      # data loading
      xy = np.loadtxt('./wine.csv', delimiter=',', dtype=np.float32, skiprows=1)
      self.x = torch.from_numpy(xy[:, 1:])  # the first column is the output label
      self.y = torch.from_numpy(xy[:, [0]]) # n_samples, 1
      self.n_samples = xy.shape[0]

    # this function allows indexing in our dataset
    def __getitem__(self, index):
      return self.x[index], self.y[index] # the function returns a tuple.

    # this allows us to call len on our dataset.
    def __len__(self):
      return self.n_samples

dataset = WineDataset()
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=2)

dataiter = iter(dataloader)
data = next(dataiter)
features, labels = data
print(features, labels)

My question is: since we can already call enumerate directly on the dataloader, does that mean the dataloader object is an iterable?
If so, would calling iter(dataloader) be the same as creating an iterator object from an iterator object?

I’m a bit confused about this, so please help me out.

I’d like to know what enumerate does behind the scenes when the dataloader is passed as an argument, and also what iter(dataloader) does.

Asked By: Suryansh Sinha


Answers:

An iterable is something that implements the __iter__ method. An iterator is something that implements the __next__ method. Both iter() and enumerate() call the __iter__ method of the object passed to them. For example:

class A:  # this is an iterable
    def __iter__(self):
        print('iter called at A')
        return B()

class B:  # this is an iterator
    def __next__(self):
        print('next called at B')
        return 1

Note that any object of class B is an iterator because it implements __next__, but it’s not an iterable because it doesn’t have an __iter__ method. Similarly, any object of class A is an iterable but not an iterator.

Run it:

a = A()

Create an iterator:

b = iter(a)
print(f'{type(b)=}')
"""
iter called at A
type(b)=<class '__main__.B'>
"""

Calling next() on the iterator b works:

next(b)
"""
next called at B
1
"""

Can’t call next() on the iterable a

next(a)
"""
TypeError: 'A' object is not an iterator
"""

We can do a for loop on a

for i in a:
    print(i)
    break
"""
iter called at A
next called at B
1
"""
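
Roughly speaking, a for loop calls iter() once on the object and then keeps calling next() on the result until StopIteration is raised. Below is a minimal sketch of that idea. It uses a plain list instead of class A (whose iterator never stops), and the helper name manual_for is made up for illustration.

```python
def manual_for(iterable):
    """Collect items in the order a for loop would visit them."""
    seen = []
    iterator = iter(iterable)      # calls iterable.__iter__()
    while True:
        try:
            item = next(iterator)  # calls iterator.__next__()
        except StopIteration:      # raised once the iterator is exhausted
            break
        seen.append(item)
    return seen

print(manual_for([10, 20, 30]))  # [10, 20, 30]
```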

Can’t do a for loop on b

for i in b:
    print(i)
    break
"""
TypeError: 'B' object is not iterable
"""
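
The Python documentation expects an iterator to also implement __iter__ and return itself from it, which is what lets an iterator be used in a for loop. Here is a small sketch under that convention. The class name C and its limit parameter are hypothetical and not part of the example above.

```python
class C:  # an iterator that is also an iterable
    def __init__(self, limit):
        self.count = 0
        self.limit = limit

    def __iter__(self):
        return self  # an iterator hands back itself

    def __next__(self):
        if self.count >= self.limit:
            raise StopIteration  # tells for loops and list() to stop
        self.count += 1
        return self.count

counter = C(3)
print(iter(counter) is counter)  # True
print(list(counter))             # [1, 2, 3]
```

Once exhausted, such an iterator stays exhausted: a second list(counter) gives an empty list.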

Now, call enumerate

c = enumerate(a)
print(f'{type(c)=}')
"""
iter called at A
type(c)=<class 'enumerate'>
"""

Can do a for loop on c as well as call next()

next(c)
"""
next called at B
(0, 1)
"""
for i in c:
    print(i)
    break
"""
next called at B
(1, 1)
"""

So the enumerate class is both an iterator and an iterable because it has both __iter__ and __next__ methods. You can check this by calling dir(c).
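
To make the "behind the scenes" part concrete, here is a rough pure-Python stand-in for enumerate. It is only a sketch of the observable behaviour, not how the built-in is actually implemented, and the name MyEnumerate is made up.

```python
class MyEnumerate:
    """Illustrative stand-in for the built-in enumerate."""
    def __init__(self, iterable, start=0):
        self._it = iter(iterable)  # the argument's __iter__ is called here, once
        self._n = start

    def __iter__(self):
        return self  # like enumerate, the object is its own iterator

    def __next__(self):
        item = next(self._it)  # StopIteration from the inner iterator propagates
        pair = (self._n, item)
        self._n += 1
        return pair

e = MyEnumerate(["a", "b"])
print(next(e))  # (0, 'a')
print(list(e))  # [(1, 'b')]

# The real enumerate object has both methods and returns itself from iter().
real = enumerate(["a", "b"])
print(hasattr(real, "__iter__"), hasattr(real, "__next__"), iter(real) is real)
# True True True
```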

When we call enumerate on a DataLoader, its __iter__ method is called. Looking at the signature of the __iter__ function in the PyTorch source code:

class DataLoader(Generic[T_co]):
.
.
    def __iter__(self) -> '_BaseDataLoaderIter':

This _BaseDataLoaderIter class implements both __iter__ and __next__, so it’s both an iterable and an iterator.

class _BaseDataLoaderIter(object):
.
.
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def __next__(self) -> Any:
    .
    .
        return data

So you can call both enumerate() and iter() on a DataLoader, and you can also loop over it with for. You can check the source code in your Python installation, somewhere at
..\Lib\site-packages\torch\utils\data\dataloader.py
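
To tie this back to the question: calling iter() on the loader is not the same as calling iter() on the iterator it returns. The sketch below is a pure-Python stand-in that only mirrors the structure shown above; it is not PyTorch's implementation, and ToyLoader and ToyLoaderIter are hypothetical names.

```python
class ToyLoaderIter:
    """Iterator: yields batches and returns itself from __iter__."""
    def __init__(self, data, batch_size):
        self._data = data
        self._batch_size = batch_size
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos >= len(self._data):
            raise StopIteration
        batch = self._data[self._pos:self._pos + self._batch_size]
        self._pos += self._batch_size
        return batch


class ToyLoader:
    """Iterable: in this stand-in, each __iter__ call builds a new iterator."""
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        return ToyLoaderIter(self.data, self.batch_size)


loader = ToyLoader(list(range(6)), batch_size=4)
it = iter(loader)
print(next(it))            # [0, 1, 2, 3]
print(iter(it) is it)      # True: iter() on the iterator returns the same object
print(iter(loader) is it)  # False: iter() on the loader builds another iterator
for i, batch in enumerate(loader):
    print(i, batch)        # 0 [0, 1, 2, 3] then 1 [4, 5]
```

So iter(dataloader) turns an iterable into an iterator, while calling iter() again on that iterator just returns the same iterator.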

Answered By: Nelson aka SpOOKY