PyTorch custom transformation with additional argument in __call__


I have a custom dataset that I want to train a neural network on. A sample of the dataset might be [1,2,3,4] and the corresponding time axis is then for example [0, 0.2, 0.4, 0.6].

This time axis is different for every sample in the dataset and is needed for certain transformations.

I only want to train the neural network on the actual data ([1,2,3,4]). Therefore in my custom transformation I need to pass in an additional time list only used for that transformation. However I have not found any example of how to accomplish this.

In the examples I have read, the transformation's __call__ always takes only the "sample" as input, like this:

def __call__(self, sample):

I could pass the time axis as part of the sample, but then wouldn't the neural network also train on the time axis? I do not want that.

How can I accomplish passing the time axis to the call function for a custom PyTorch transformation without training on the time data?

Asked By: oas



You control how the transformations are called in the dataset, so if you write your own dataset you can transform your sample with whatever extra data you want directly in __getitem__.
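
As a rough sketch of that first option, the dataset below applies a time-dependent transform inside __getitem__ and returns only the sample. The names TimeSeriesDataset and rate_of_change are hypothetical, and the class is plain Python so it runs without dependencies; in real code it would presumably subclass torch.utils.data.Dataset and return tensors.

```python
def rate_of_change(sample, time_axis):
    """Hypothetical transform: finite-difference slope between neighbouring points."""
    return [
        (s1 - s0) / (t1 - t0)
        for s0, s1, t0, t1 in zip(sample, sample[1:], time_axis, time_axis[1:])
    ]


class TimeSeriesDataset:
    """Map-style dataset sketch: implements only __getitem__ and __len__."""

    def __init__(self, samples, time_axes, transform=None):
        self.samples = samples
        self.time_axes = time_axes
        # Assumption: transform is called as transform(sample, time_axis)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        time_axis = self.time_axes[index]
        if self.transform is not None:
            # The time axis is only used here, inside the transform
            sample = self.transform(sample, time_axis)
        # Only the sample is returned, so the network never sees the time axis
        return sample


dataset = TimeSeriesDataset(
    samples=[[1, 2, 3, 4], [2, 4, 8, 16]],
    time_axes=[[0, 0.2, 0.4, 0.6], [0.0, 0.5, 1.0, 1.5]],
    transform=rate_of_change,
)
first = dataset[0]  # three slopes computed from four points
```

Because the time axis never leaves __getitem__, nothing downstream of the dataset needs to know it exists.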

If you want to follow the model of separating your transforms from your dataset (which is probably a good practice), then you can write your dataset to expect transforms that take both sample and time-axis. Since torchvision’s built-in transforms don’t expect the time-axis you can write a wrapper to apply them only to the sample argument. One caveat is that if we want to continue using torchvision’s Compose transform then we need our transforms to take a single argument. We could write a custom compose pretty easily but it’s a bit easier IMO to just pack all the arguments into a single tuple argument.
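
For comparison, the custom compose alternative mentioned above might look roughly like this. It is only a sketch: ComposeWithTime, SampleOnly, weight_by_time and double are hypothetical names, and it assumes every transform is called as t(sample, time_axis) and leaves the time axis unchanged.

```python
class ComposeWithTime:
    """Chains transforms that are called as t(sample, time_axis) and return a new sample."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, sample, time_axis):
        for t in self.transforms:
            sample = t(sample, time_axis)
        return sample


class SampleOnly:
    """Adapts a one-argument transform so it fits the two-argument convention."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample, time_axis):
        # The time axis is accepted but deliberately ignored
        return self.transform(sample)


def weight_by_time(sample, time_axis):
    """Hypothetical time-dependent transform: scale each value by its timestamp."""
    return [value * t for value, t in zip(sample, time_axis)]


def double(sample):
    """Hypothetical transform that does not need the time axis."""
    return [2 * value for value in sample]


pipeline = ComposeWithTime([weight_by_time, SampleOnly(double)])
result = pipeline([1, 2, 3, 4], [0, 0.2, 0.4, 0.6])
```

If a transform had to modify the time axis as well (resampling, for example), each step would need to return both values, which is essentially what the tuple approach does.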

An incomplete example of the tuple approach (you need to fill in the ... sections) might look something like this:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class TransformWrapper:
    """ Wraps a transform that operates on only the sample """
    def __init__(self, t):
        self.t = t

    def __call__(self, data):
        """
        data: tuple containing both sample and time_axis
        returns a tuple containing the transformed sample and original time_axis
        """
        sample, time_axis = data
        return self.t(sample), time_axis

class CustomTransform:
    """ a custom transform dependent on time axis """
    def __init__(self, ...):
        ... # store whatever parameters the transform needs

    def __call__(self, data):
        sample, time_axis = data
        new_sample = ... # some function of sample and time_axis
        return new_sample, time_axis

class MyDataset(Dataset):
    def __init__(self, root, transform=None):
        """
        root: ...
        transform: A transform that operates on a tuple containing sample and time_axis
        """
        ... # init dataset
        self.transform = transform

    def __getitem__(self, index):
        sample, time_axis = self.get_data(index)
        if self.transform is not None:
            # transform operates on a tuple containing both sample and time_axis
            sample, time_axis = self.transform((sample, time_axis))

        # dataset doesn't need to return time_axis
        return sample

    def get_data(self, index):
        ... # load and return sample and time_axis at index

    def __len__(self):
        ... # returns length of data

# example of how to compose wrapped transforms
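dataset = MyDataset(
    ... # root, plus a transforms.Compose([...]) of CustomTransform and TransformWrapper(...) instances
)
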
loader = DataLoader(dataset, ...)

# train loop
for samples in loader:
    ... # training step; each batch contains only samples, never the time axis

Answered By: jodag