PyTorch: TypeError when I call my data_transforms function while defining my train_dataset object

Question:

When I try to create a dataset object with my get_data function:

train = get_data(root="My_train_path",
                 transform=data_transforms[TRAIN])

it raises TypeError: 'function' object is not subscriptable.


import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

data_dir = 'my_dataset_dir'

TEST  = 'test'
TRAIN = 'train'
VAL   = 'val'

def data_transforms(phase):
    """Return the albumentations pipeline for the given phase."""
    if phase == TRAIN:
        transform = A.Compose([
            A.CLAHE(clip_limit=4.0, p=0.7),
            A.CoarseDropout(max_height=8, max_width=8, max_holes=8, p=0.5),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
    elif phase == VAL:
        transform = A.Compose([
            A.Resize(height=256, width=256),
            A.CenterCrop(height=224, width=224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
    elif phase == TEST:
        transform = A.Compose([
            A.Resize(height=256, width=256),
            A.CenterCrop(height=224, width=224),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ])
    else:
        # Without this branch, an unknown phase would leave `transform` unbound.
        raise ValueError(f"Unknown phase: {phase!r}")
    return transform

def get_data(root, transform):
    # CustomImageFolder is the asker's own Dataset class (definition not shown).
    # Pass the caller's root through instead of the hard-coded ".".
    image_dataset = CustomImageFolder(root=root, transform=transform)
    return image_dataset

def make_loader(dataset, batch_size, shuffle, num_workers):
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         pin_memory=True,
                                         num_workers=num_workers)
    return loader
Asked By: akshat nayak


Answers:

The problem is exactly what the error message says: data_transforms is a function you have defined, and you want to call it with the training phase as its argument. However, you are subscripting the function with square brackets ([]), and function objects do not support subscripting. To fix this, replace the square brackets with parentheses (()), as in any function call.

That is,

train = get_data(root="My_train_path",
                 transform=data_transforms(TRAIN))
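To see the difference in isolation, here is a minimal, dependency-free sketch. It assumes a simplified data_transforms that returns placeholder strings instead of A.Compose pipelines, and the dict name transforms_by_phase is hypothetical. The dict shows the alternative if you would rather keep the square-bracket lookup.

```python
TRAIN, VAL, TEST = "train", "val", "test"

def data_transforms(phase):
    # Simplified stand-in (hypothetical): returns placeholder strings
    # instead of real A.Compose pipelines.
    if phase == TRAIN:
        return "train-pipeline"
    elif phase in (VAL, TEST):
        return "eval-pipeline"
    raise ValueError(f"Unknown phase: {phase!r}")

# Subscripting a function object fails: functions define no __getitem__.
try:
    data_transforms[TRAIN]
except TypeError as exc:
    error_message = str(exc)

print(error_message)  # 'function' object is not subscriptable

# Calling the function with parentheses works.
train_transform = data_transforms(TRAIN)

# Alternative (hypothetical name): build each pipeline once and look it up
# by key, which is what the original square-bracket syntax expects.
transforms_by_phase = {phase: data_transforms(phase) for phase in (TRAIN, VAL, TEST)}
assert transforms_by_phase[TRAIN] == train_transform
```

With the dict approach, each pipeline is constructed once up front rather than on every call, and transforms_by_phase[TRAIN] becomes valid syntax.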
Answered By: GoodDeeds