TensorFlow version of PyTorch transforms

Question:

I have the following code that I use to prepare images before performing inference in a model:

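from PIL import Image
from torchvision import transforms

def image_loader(transform, image_name):
    image = Image.open(image_name)
    # transform: PIL image -> normalized (3, H, W) float tensor
    image = transform(image).float()
    image = image.unsqueeze(0)  # add batch dimension -> (1, 3, H, W)
    return image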

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

I've converted the model into a TensorFlow model, but I'm unsure how to apply similar transformations to images before inference, since there doesn't seem to be an equivalent of torchvision's transforms. Any advice?

Asked By: ConnorLloyd


Answers:

Here are some pointers. In PyTorch, you have:

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision import transforms

def image_loader(transform, image_name):
    image = Image.open(image_name).convert('RGB')  # force 3 channels (drops alpha)
    image = transform(image).float()               # (3, H, W) normalized float tensor
    image = image.unsqueeze(0)                     # add batch dimension -> (1, 3, H, W)
    return image

data_transforms = transforms.Compose([
    transforms.ToTensor(),  # (H, W, 3) uint8 in [0, 255] -> (3, H, W) float in [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# check: visualize
i = image_loader(data_transforms, '/content/1.png')
print(i.shape)  # torch.Size([1, 3, H, W])

plt.figure(figsize=(25, 10))
plt.subplot(121)
plt.imshow(np.array(i[0]).transpose(1, 2, 0))  # CHW -> HWC for matplotlib
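
For reference, the two transforms above amount to simple arithmetic: ToTensor() turns an (H, W, 3) uint8 image into a (3, H, W) float tensor scaled to [0, 1], and Normalize() then applies (x - mean) / std per channel. A minimal NumPy sketch of that computation, assuming an 8-bit RGB input (the helper name to_tensor_and_normalize is hypothetical, not a torchvision function):

```python
import numpy as np

# ImageNet statistics, as passed to transforms.Normalize above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_tensor_and_normalize(rgb_uint8):
    """Sketch of ToTensor() + Normalize() + unsqueeze(0) on an (H, W, 3) uint8 array."""
    x = rgb_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = (x - MEAN) / STD                      # Normalize: broadcasts over the channel axis
    x = x.transpose(2, 0, 1)                  # ToTensor: HWC -> CHW
    return x[np.newaxis]                      # unsqueeze(0): add batch dimension

img = np.random.randint(0, 256, size=(4, 5, 3), dtype=np.uint8)
batch = to_tensor_and_normalize(img)
print(batch.shape)                        # (1, 3, 4, 5)
print(batch.transpose(0, 2, 3, 1).shape)  # channels-last: (1, 4, 5, 3)
```

Transposing with (0, 2, 3, 1) gives the channels-last layout that TensorFlow code usually works with.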

And in TensorFlow, you can achieve the same as follows:

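import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

def transform(image, mean, std):
    # normalize each channel of an (H, W, 3) float array in place: (x - mean) / std
    for channel in range(3):
        image[:, :, channel] = (image[:, :, channel] - mean[channel]) / std[channel]
    return image

def image_loader(image_name):
    image = Image.open(image_name).convert('RGB')
    image = transform(np.array(image) / 255.0,  # scale to [0, 1], like ToTensor()
                      mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225])
    image = tf.cast(image, tf.float32)
    image = tf.expand_dims(image, 0)  # add batch dimension -> (1, H, W, 3)
    return image

# check: visualize
i = image_loader('/content/1.png')
print(i.shape)  # (1, H, W, 3)

plt.figure(figsize=(25, 10))
plt.subplot(121)
plt.imshow(i[0].numpy())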
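
The per-channel loop in transform can also be written as a single broadcast expression, because an (H, W, 3) array broadcasts against a length-3 mean and std. A small NumPy sketch comparing the two (normalize_loop and normalize_broadcast are illustrative names; unlike the original, normalize_loop works on a copy instead of modifying its input):

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_loop(image, mean, std):
    # per-channel loop, same math as transform() above, on a copy
    image = image.copy()
    for channel in range(3):
        image[:, :, channel] = (image[:, :, channel] - mean[channel]) / std[channel]
    return image

def normalize_broadcast(image, mean, std):
    # one expression: shape (H, W, 3) broadcasts against shape (3,)
    return (image - mean) / std

img = np.random.rand(4, 5, 3).astype(np.float32)
print(np.allclose(normalize_loop(img, MEAN, STD), normalize_broadcast(img, MEAN, STD)))  # True
```

TensorFlow ops follow the same broadcasting rules, so the broadcast form should also work directly on a tf.Tensor, e.g. (tf.cast(image, tf.float32) / 255.0 - mean) / std with mean and std as length-3 float32 values.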

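This should produce the same values. The one difference is the layout: the PyTorch loader returns a channels-first tensor of shape (1, 3, H, W), while the TensorFlow loader returns channels-last (1, H, W, 3), so transpose if your converted model still expects the PyTorch layout. Note also that in the second case we write the transform function ourselves instead of using a built-in op. That's fine here, but you can also check tf.keras…Normalization; see this answer for details.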
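
If you'd rather keep the normalization inside a Keras model, here is a sketch using tf.keras.layers.Normalization. It assumes a TF 2.x release in which this layer accepts fixed mean and variance arguments (older releases exposed it under tf.keras.layers.experimental.preprocessing), and whether this is exactly the layer the answer refers to is an assumption. The layer divides by the square root of the variance, so pass std squared:

```python
import numpy as np
import tensorflow as tf

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Normalization expects a variance, not a standard deviation, so square std.
# Inputs are assumed to be already scaled to [0, 1], as after ToTensor().
normalize = tf.keras.layers.Normalization(
    axis=-1,                              # normalize along the channel (last) axis
    mean=mean,
    variance=np.square(std).tolist(),
)

batch = np.random.rand(1, 4, 5, 3).astype(np.float32)  # NHWC, values in [0, 1]
out = normalize(batch)
print(out.shape)  # (1, 4, 5, 3)
```

Since it is a layer, it could be placed at the front of the converted model so that inference code only has to scale pixels to [0, 1].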

Answered By: M.Innat