PyTorch custom transformation with additional argument in __call__

Question:

I have a custom dataset that I want to train a neural network on. A sample of the dataset might be [1,2,3,4] and the corresponding time axis is then for example [0, 0.2, 0.4, 0.6].

This time axis is different for every sample in the dataset and is needed for certain transformations.

I only want to train the neural network on the actual data ([1,2,3,4]). Therefore, in my custom transformation, I need to pass in an additional time list that is used only for that transformation. However, I have not found any example of how to accomplish this.

I have read https://pytorch.org/tutorials/beginner/data_loading_tutorial.html, but in their transformations the __call__ method only ever takes the "sample" as input, like this:

def __call__(self, sample):

I could pass the time axis as part of the sample, but then wouldn’t the neural network also train on the time axis, which I do not want?

How can I accomplish passing the time axis to the call function for a custom PyTorch transformation without training on the time data?

Asked By: oas


Answers:

You control how the transformations are called in the dataset, so if you write your own dataset you can transform your sample with whatever extra data you want directly in __getitem__.
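
For instance, here is a minimal sketch of that approach, assuming the data is already held in memory as lists (the class name, the time_transform helper, and the transform body are all hypothetical placeholders):

import torch
from torch.utils.data import Dataset


class InlineTransformDataset(Dataset):
    """ Sketch: applies a time-dependent transform directly in __getitem__ """
    def __init__(self, samples, time_axes):
        # samples[i] is the raw data, time_axes[i] is its time axis
        self.samples = samples
        self.time_axes = time_axes

    def __getitem__(self, index):
        sample = self.samples[index]
        time_axis = self.time_axes[index]
        # hypothetical transform that uses both sample and time_axis
        sample = self.time_transform(sample, time_axis)
        # only the transformed sample is returned, so the
        # network never sees the time axis
        return torch.tensor(sample, dtype=torch.float32)

    def time_transform(self, sample, time_axis):
        # placeholder: any function of both sample and time_axis
        return [s / (t + 1.0) for s, t in zip(sample, time_axis)]

    def __len__(self):
        return len(self.samples)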

If you want to follow the model of separating your transforms from your dataset (which is probably good practice), then you can write your dataset to expect transforms that take both the sample and the time axis. Since torchvision’s built-in transforms don’t expect a time axis, you can write a wrapper that applies them only to the sample argument. One caveat is that if we want to keep using torchvision’s Compose transform, then our transforms need to take a single argument. We could write a custom compose pretty easily (see the sketch right below), but it’s a bit easier IMO to just pack all the arguments into a single tuple argument.
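
For reference, such a custom compose is only a few lines. This sketch (not part of torchvision) assumes every transform in the list takes and returns the pair (sample, time_axis):

class ComposeWithTime:
    """ Sketch: composes transforms that take both sample and time_axis """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample, time_axis):
        for t in self.transforms:
            sample, time_axis = t(sample, time_axis)
        return sample, time_axis

The tuple-packing approach below avoids this extra class at the cost of a slightly less explicit call signature.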

An incomplete example (you need to fill in the ... sections yourself) might look something like this:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class TransformWrapper:
    """ Wraps a transform that operates on only the sample """
    def __init__(self, t):
        self.t = t

    def __call__(self, data):
        """
            data: tuple containing both sample and time_axis
            returns a tuple containing the transformed sample and original time_axis
        """
        sample, time_axis = data
        return self.t(sample), time_axis


class CustomTransform:
    """ a custom transform dependent on time axis """
    def __init__(self, ...):
        ...

    def __call__(self, data):
        sample, time_axis = data
        new_sample = ... # some function of sample and time_axis
        return new_sample, time_axis


class MyDataset(Dataset):
    def __init__(self, root, transform=None):
        """
            root: ...
            transform: A transform that operates on a tuple containing sample and time_axis
        """
        ... # init dataset
        self.transform = transform

    def __getitem__(self, index):
        sample, time_axis = self.get_data(index)
        if self.transform is not None:
            # transform operates on a tuple containing both sample and time_axis
            sample, time_axis = self.transform((sample, time_axis))

        # dataset doesn't need to return time_axis
        return sample

    def get_data(self, index):
        ... # load and return sample and time_axis at index

    def __len__(self):
        ... # returns length of data


# example of how to compose wrapped transforms
dataset = MyDataset(
    root=...,
    transform=transforms.Compose([
        TransformWrapper(transforms.Resize(256)),  # torchvision's built-in resize (there is no transforms.Rescale)
        TransformWrapper(transforms.RandomCrop(224)),
        CustomTransform(...),
        TransformWrapper(transforms.ToTensor())
    ]))

loader = DataLoader(dataset, ...)

# train loop
for samples in loader:
    ...
Answered By: jodag