Optimizing a PyTorch function to eliminate a for loop

Question:

Lately I have been developing a function that works with tensors of shape:

torch.Size([51, 265, 23, 23])

where the first dimension is time, the second is the pattern index, and the last two are the spatial size of each pattern.

Each individual pattern can contain at most 3 states: [-1, 0, 1]. A pattern is considered ‘alive’ when all 3 states are present,
and ‘dead’ in every other case, that is, when at least one of the 3 states is missing.

My objective is to filter out all the dead patterns by checking the last time step (the last slice along the first dimension) of the tensor.

My current implementation (that works) is:

import torch

def filter_patterns(tensor_sims):

   # Get the indices of the patterns (dim=1) that need to be kept:
   # those whose last time step contains exactly 3 distinct values
   keep_indices = torch.tensor(
      [i for i in range(tensor_sims.shape[1])
       if tensor_sims[-1, i].unique().numel() == 3],
      dtype=torch.long,  # keeps indexing valid even if no pattern survives
   )

   # Keep only the patterns that meet the condition
   tensor_sims = tensor_sims[:, keep_indices]

   print(f'Number of patterns: {tensor_sims.shape[1]}')
   return tensor_sims

Unfortunately I’m not able to get rid of the for loop.

I tried to play around with the torch.unique() function and with the parameter dim, I tried reducing the dimensions of the tensor and flattening, but nothing worked.
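
One reason the dim parameter does not help here, sketched on a made-up toy tensor: with dim set, torch.unique deduplicates whole slices along that dimension rather than returning the distinct values inside each slice. The tensor below is illustrative and not taken from the real data.

```python
import torch

# Toy stand-in for the flattened last time step: 3 patterns, 4 cells each
x = torch.tensor([[-1, 0, 1, 1],
                  [ 0, 0, 0, 0],
                  [-1, 0, 1, 1]])

# With dim=0, unique() removes duplicate rows;
# it does not report the distinct values found in each row
unique_rows = torch.unique(x, dim=0)
print(unique_rows.shape)

# Per-row distinct counts still need one unique() call per row
per_row_counts = [row.unique().numel() for row in x]
print(per_row_counts)
```

Rows 0 and 2 are identical, so only two rows survive the dim=0 call, while the per-row counts have to be gathered in a Python loop.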

Found Solution (thanks to the answer):

def filter_patterns(tensor_sims):
   # Flatten the spatial dimensions of the last timestep
   x_ = tensor_sims[-1].flatten(1)

   # Create masks to identify -1, 0, and 1 conditions
   mask_minus_one = (x_ == -1).any(dim=1)
   mask_zero = (x_ == 0).any(dim=1)
   mask_one = (x_ == 1).any(dim=1)

   # Combine the masks using logical_and
   mask = mask_minus_one.logical_and(mask_zero).logical_and(mask_one)

   # Keep only the columns that meet the condition
   tensor_sims = tensor_sims[:, mask]

   print(f'Number of patterns: {tensor_sims.shape[1]}')
   return tensor_sims

The new implementation is much faster.
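
Before comparing speed, it can help to confirm that the two versions select the same patterns. The following is a minimal sketch, assuming the data only contains the values -1, 0, and 1. The names filter_loop and filter_mask and the small random tensor are illustrative, not part of the original code.

```python
import torch

def filter_loop(t):
    # Loop version: one unique() call per pattern at the last time step
    keep = [i for i in range(t.shape[1]) if t[-1, i].unique().numel() == 3]
    return t[:, torch.tensor(keep, dtype=torch.long)]

def filter_mask(t):
    # Vectorized version: test that each state occurs at least once
    x_ = t[-1].flatten(1)
    mask = (x_ == -1).any(dim=1) & (x_ == 0).any(dim=1) & (x_ == 1).any(dim=1)
    return t[:, mask]

torch.manual_seed(0)
# Hypothetical data: values drawn from {-1, 0, 1}, shape (time, pattern, h, w)
sims = torch.randint(-1, 2, (5, 40, 3, 3))
# Force the first 10 patterns to be dead at the last time step
sims[-1, :10] = 0

same = torch.equal(filter_loop(sims), filter_mask(sims))
print(same, filter_mask(sims).shape)
```

The two versions agree only under the stated assumption: if other values could appear, a pattern with 3 distinct values would pass the loop version without necessarily containing -1, 0, and 1.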

Asked By: Fabrizio Brown

Answers:

I don’t believe you can get away with torch.unique because it won’t work per column. Instead of iterating over dim=1, you could construct three mask tensors that check for the values -1, 0, and 1, respectively, and then combine them with some basic logic to obtain the per-pattern mask.

Considering you only check on the last timestep, focus on that and flatten the spatial dimensions:

x_ = x[-1].flatten(1)
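
As a quick shape check of this step, here is a sketch using a zero-filled stand-in tensor with the shape given in the question (the real data is not needed to see the effect):

```python
import torch

# Stand-in tensor shaped (time, pattern, height, width) as in the question
x = torch.zeros(51, 265, 23, 23)

last = x[-1]          # last time step: (265, 23, 23)
x_ = last.flatten(1)  # merge the two spatial dims: (265, 529)
print(last.shape, x_.shape)
```

Each row of x_ then holds all 23 * 23 = 529 cells of one pattern, so reductions over dim=1 act per pattern.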

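The three masks that identify the -1, 0, and 1 conditions can be obtained with x_ == -1, x_ == 0, and x_ == 1, respectively. A pattern is alive only if each value occurs at least once, so reduce every mask with any over the flattened dimension:

has_minus_one = (x_ == -1).any(dim=1)
has_zero = (x_ == 0).any(dim=1)
has_one = (x_ == 1).any(dim=1)

Finally, combine them with torch.logical_and so that only the patterns containing all three values are kept:

keep_indices = has_minus_one.logical_and(has_zero).logical_and(has_one)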
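
A toy check of the presence test, assuming the goal stated in the question (a pattern is kept only when -1, 0, and 1 all occur in its last time step). This sketch uses a broadcasting variant rather than three separate masks, and the tensor values and the names states, present, and keep are made up for illustration.

```python
import torch

# Made-up flattened last time step: 4 patterns, 5 cells each
x_ = torch.tensor([[-1, 0, 1, 0, 1],    # all three states present
                   [ 0, 0, 0, 0, 0],    # only 0
                   [-1, 1, 1, -1, 1],   # 0 is missing
                   [ 1, 0, -1, -1, 0]]) # all three states present

states = torch.tensor([-1, 0, 1])

# Compare every cell against every state: shape (patterns, cells, states)
# present[i, k] is True when states[k] occurs somewhere in pattern i
present = (x_.unsqueeze(-1) == states).any(dim=1)

# Keep a pattern only if every state is present
keep = present.all(dim=1)
print(keep)  # tensor([ True, False, False,  True])
```

The broadcast comparison allocates a boolean tensor of shape (patterns, cells, 3), so for large inputs three separate masks may use less memory.
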
Answered By: Ivan