PyTorch: How to apply the same random transformation to multiple images?

Question:

I am writing a simple transformation for a dataset that contains many pairs of images. As data augmentation, I want to apply a random transformation to each pair, but both images in a pair should be transformed in the same way.
For example, given a pair of images A and B, if A is flipped horizontally, B must be flipped horizontally as well. The next pair, C and D, may be transformed differently from A and B, but C and D must be transformed in the same way as each other. I am trying to do this as follows:

import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

img_a = Image.open("sample_a.jpg") # note that the two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

transform = transforms.RandomChoice(
    [transforms.RandomHorizontalFlip(), 
     transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))

random.seed(1)
display(transform(img_c))
display(transform(img_d))

Yet, the above code does not choose the same transformation, and from my tests the result depends on the number of times transform is called.

Is there any way to force transforms.RandomChoice to use the same transform when specified?

Asked By: TFC


Answers:

I don't know of a function to fix the random output.
Maybe try a different approach: do the randomization yourself so that you can reuse the same transformation.
Logic:

  • generate a random number
  • based on that number, apply a transformation to both images
  • generate another random number
  • do the same for the other two images

Try this:
import random
import torchvision.transforms.functional as TF
from PIL import Image

img_a = Image.open("sample_a.jpg") # note that the two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

# One random draw for the first pair: both images get the same flip
if random.random() > 0.5:
    image_a_flipped = TF.vflip(img_a)
    image_b_flipped = TF.vflip(img_b)
else:
    image_a_flipped = TF.hflip(img_a)
    image_b_flipped = TF.hflip(img_b)

# A fresh draw for the second pair
if random.random() > 0.5:
    image_c_flipped = TF.vflip(img_c)
    image_d_flipped = TF.vflip(img_d)
else:
    image_c_flipped = TF.hflip(img_c)
    image_d_flipped = TF.hflip(img_d)

# display() is available in Jupyter/IPython
display(image_a_flipped)
display(image_b_flipped)

display(image_c_flipped)
display(image_d_flipped)
Answered By: Salman Hammad

Usually a workaround is to apply the transform to the first image, retrieve the parameters of that transform, and then apply a deterministic transform with those parameters to the remaining images. However, RandomChoice does not provide an API to get the parameters of the applied transform, since it involves a variable number of transforms.
In such cases, I usually override the original class.

Looking at the torchvision implementation, it’s as simple as:

class RandomChoice(RandomTransforms):
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)

Here are two possible solutions.

  1. You can either sample from the transform list on __init__ instead of on __call__:

    import random
    import torch
    import torchvision.transforms as T
    
    class RandomChoice(torch.nn.Module):
        def __init__(self, transforms):
            super().__init__()
            # Choose one transform once, at construction time
            self.t = random.choice(transforms)
    
        def __call__(self, img):
            return self.t(img)
    

    So you can do:

    # p=1 so the chosen flip is always applied; with the default p=0.5,
    # each call would still decide independently whether to flip
    transform = RandomChoice([
         T.RandomHorizontalFlip(p=1), 
         T.RandomVerticalFlip(p=1)
    ])
    display(transform(img_a)) # both img_a and img_b will
    display(transform(img_b)) # have the same transform
    
    transform = RandomChoice([
        T.RandomHorizontalFlip(p=1), 
        T.RandomVerticalFlip(p=1)
    ])
    display(transform(img_c)) # both img_c and img_d will
    display(transform(img_d)) # have the same transform
    

  2. Or better yet, transform the images in batch:

    import random
    import torch
    import torchvision.transforms as T
    
    class RandomChoice(torch.nn.Module):
        def __init__(self, transforms):
            super().__init__()
            self.transforms = transforms
    
        def __call__(self, imgs):
            # Choose one transform per call and apply it to every image in the list
            t = random.choice(self.transforms)
            return [t(img) for img in imgs]
    

    This allows you to do:

    # p=1 so the chosen flip is always applied to every image in the list
    transform = RandomChoice([
         T.RandomHorizontalFlip(p=1), 
         T.RandomVerticalFlip(p=1)
    ])
    
    img_at, img_bt = transform([img_a, img_b])
    display(img_at) # both img_a and img_b will
    display(img_bt) # have the same transform
    
    img_ct, img_dt = transform([img_c, img_d])
    display(img_ct) # both img_c and img_d will
    display(img_dt) # have the same transform
    
Answered By: Ivan

Simply move the randomization out of the PyTorch transform and into an if statement.
The code below uses vflip; the same works for horizontal flips and other transforms.

import random
import torchvision.transforms.functional as TF

if random.random() > 0.5:
    image = TF.vflip(image)
    mask  = TF.vflip(mask)

This issue has been discussed on the PyTorch forum, and the pros and cons of several solutions were discussed on the official GitHub repository page.
PyTorch maintainers have suggested this simple approach.

Do not use torchvision.transforms.RandomVerticalFlip(p=1). Use torchvision.transforms.functional.vflip instead.

Functional transforms give you fine-grained control of the transformation pipeline. As opposed to the transformations above, functional transforms don’t contain a random number generator for their parameters. That means you have to specify/generate all parameters, but you can reuse the functional transform.

Answered By: Abhi25t

I realize the OP requested a solution using torchvision, and I think @Ivan’s answer does a good job of addressing this.

However, for those not tied to a specific augmentation library, I wanted to point out that Albumentations appears to handle these kinds of situations nicely and natively, by allowing the user to pass multiple source images, boxes, etc. into the same transform. The result is returned as a dict:

import albumentations as A

transform = A.Compose(
    transforms=[
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5)],
    additional_targets={'image0': 'image', 'image1': 'image'}
)
transformed = transform(image=image, image0=image0, image1=image1)

Now you can access transformed['image'], transformed['image0'], transformed['image1'], etc., and all of them will have had the same randomly sampled transformation applied.

Answered By: Addison Klinke

Referencing "Random transforms for both input and target?", I think this is probably the cleanest way to do it: save the random state before applying any transformation, and then restore it before each subsequent call.

import torch
import torchvision.transforms as transforms

t = transforms.RandomRotation(degrees=360)
state = torch.get_rng_state()  # save the RNG state before the first call
x = t(x)
torch.set_rng_state(state)  # restore it so the second call draws the same angle
y = t(y)
Answered By: Ivan Gonzalez

I think I have a simple solution:
If the images are concatenated, the transformations are applied to all of them identically:

import torch
import torchvision.transforms as T

# Create two fake images (identical for test purposes):
image = torch.randn((3, 128, 128))
target = image.clone()

# This is the trick (concatenate the images):
both_images = torch.cat((image.unsqueeze(0), target.unsqueeze(0)),0)

# Apply the transformations to both images simultaneously:
transformed_images = T.RandomRotation(180)(both_images)

# Get the transformed images:
image_trans = transformed_images[0]
target_trans = transformed_images[1]

# Compare the transformed images:
torch.all(image_trans == target_trans).item()

>> True
Answered By: Mario Galindo