# Optimization of pytorch function to eliminate for loop

## Question:

Lately I have been developing a function to deal with tensors of shape

`torch.Size([51, 265, 23, 23])`

where the first dimension is time, the second is the pattern index, and the last two are the spatial size of each pattern.

Each individual pattern can take at most 3 states, `[-1, 0, 1]`, and it is considered ‘alive’ when all 3 states are present; in every other case, where the pattern doesn’t contain all 3 states, it is ‘dead’.

My objective is to filter out all the dead patterns by checking the last row (last time step) of the tensor.
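To make the alive/dead distinction concrete, here is a minimal sketch with two made-up 2×2 patterns at a single time step, one containing all three states and one missing a state:

```python
import torch

# Two 2x2 patterns at a single time step (made-up values)
alive = torch.tensor([[-1, 0], [1, 1]])  # contains -1, 0 and 1 -> alive
dead = torch.tensor([[0, 0], [1, 1]])    # only 0 and 1 appear  -> dead

print(alive.unique().numel())  # 3
print(dead.unique().numel())   # 2
```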

### My current implementation (that works):

```python
import torch

def filter_patterns(tensor_sims):
    # Indices of the patterns whose last time step contains all 3 states
    keep_indices = torch.tensor([i for i in range(tensor_sims.shape[1])
                                 if tensor_sims[-1, i].unique().numel() == 3])

    # Keep only the patterns that meet the condition
    tensor_sims = tensor_sims[:, keep_indices]

    print(f'Number of patterns: {tensor_sims.shape[1]}')
    return tensor_sims
```
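As a quick sanity check, the loop-based function can be exercised on a tiny made-up tensor (the function is repeated so the snippet runs on its own):

```python
import torch

def filter_patterns(tensor_sims):
    # Same loop-based implementation as above, repeated to be self-contained
    keep_indices = torch.tensor([i for i in range(tensor_sims.shape[1])
                                 if tensor_sims[-1, i].unique().numel() == 3])
    tensor_sims = tensor_sims[:, keep_indices]
    print(f'Number of patterns: {tensor_sims.shape[1]}')
    return tensor_sims

# Toy stand-in for torch.Size([51, 265, 23, 23]): 3 time steps, 4 patterns, 2x2
sims = torch.zeros(3, 4, 2, 2, dtype=torch.long)
sims[-1, 0] = torch.tensor([[-1, 0], [1, 0]])  # alive: all 3 states present
sims[-1, 2] = torch.tensor([[1, -1], [0, 1]])  # alive: all 3 states present

filtered = filter_patterns(sims)  # prints "Number of patterns: 2"
```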

Unfortunately, I’m not able to get rid of the for loop.

I tried playing around with the `torch.unique()` function and its `dim` parameter, and I tried reducing the dimensions of the tensor and flattening it, but nothing worked.
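For what it’s worth, `torch.unique` with a `dim` argument deduplicates whole slices along that dimension rather than reporting the unique values per slice, which is why it can’t replace the loop here. A small sketch of that behaviour:

```python
import torch

x = torch.tensor([[1, 0],
                  [1, 0],
                  [0, 1]])

# dim=0 removes duplicate *rows*; it does not report unique values per row
unique_rows = torch.unique(x, dim=0)   # two distinct rows remain
print(unique_rows.shape)               # torch.Size([2, 2])

# Without dim, unique() flattens the tensor and returns the set of values
print(torch.unique(x))                 # tensor([0, 1])
```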

## Found Solution (thanks to the answer):

```python
import torch

def filter_patterns(tensor_sims):
    # Flatten the spatial dimensions of the last time step: (patterns, H*W)
    x_ = tensor_sims[-1].flatten(1)

    # Create masks to identify the -1, 0, and 1 conditions per pattern
    has_neg = (x_ == -1).any(dim=1)
    has_zero = (x_ == 0).any(dim=1)
    has_pos = (x_ == 1).any(dim=1)

    # Combine the masks using logical_and: alive = all 3 states present
    keep_mask = has_neg.logical_and(has_zero).logical_and(has_pos)

    # Keep only the patterns that meet the condition
    tensor_sims = tensor_sims[:, keep_mask]

    print(f'Number of patterns: {tensor_sims.shape[1]}')
    return tensor_sims
```

The new implementation is dramatically faster.
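To see the difference for yourself, a rough timing harness along these lines can be used (the function names here are illustrative, and absolute timings depend on hardware; the speedup comes from replacing the per-pattern Python loop with a few vectorized comparisons):

```python
import time
import torch

def filter_loop(t):
    # Original approach: Python loop over dim=1, one unique() call per pattern
    keep = torch.tensor([i for i in range(t.shape[1])
                         if t[-1, i].unique().numel() == 3])
    return t[:, keep]

def filter_vec(t):
    # Vectorized approach: one equality mask per state, reduced over space
    x_ = t[-1].flatten(1)
    keep = ((x_ == -1).any(dim=1)
            & (x_ == 0).any(dim=1)
            & (x_ == 1).any(dim=1))
    return t[:, keep]

torch.manual_seed(0)
t = torch.randint(-1, 2, (51, 265, 23, 23))

# Both implementations should select the same patterns
assert torch.equal(filter_loop(t), filter_vec(t))

for fn in (filter_loop, filter_vec):
    start = time.perf_counter()
    fn(t)
    print(f'{fn.__name__}: {time.perf_counter() - start:.4f}s')
```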

## Answer:

I don’t believe you can get away with `torch.unique`, because it won’t work per column. Instead of iterating over `dim=1`, you can construct three mask tensors that check for `-1`, `0`, and `1` values, respectively, and combine them with some basic logic to obtain the resulting column mask.

Since you only check the last time step, focus on that and flatten the spatial dimensions:

```python
x_ = x[-1].flatten(1)
```

The three masks identifying the `-1`, `0`, and `1` conditions can be obtained with `x_ == -1`, `x_ == 0`, and `x_ == 1`, respectively. A pattern is alive only if each state appears at least once, so reduce each mask over the flattened spatial dimension with `any` and combine the per-pattern results with `torch.logical_and`:

```python
mask = ((x_ == -1).any(dim=1)
        .logical_and((x_ == 0).any(dim=1))
        .logical_and((x_ == 1).any(dim=1)))
```

Finally, use the boolean mask to keep only the alive patterns:

```python
x = x[:, mask]
```
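Putting the steps together on a toy tensor (made-up values: 2 time steps, 3 patterns of size 2×2), under the same alive/dead definition as in the question:

```python
import torch

x = torch.zeros(2, 3, 2, 2, dtype=torch.long)
x[-1, 0] = torch.tensor([[-1, 0], [1, 1]])   # alive: all 3 states present
x[-1, 1] = torch.tensor([[0, 0], [1, 1]])    # dead: -1 never appears
x[-1, 2] = torch.tensor([[1, -1], [0, -1]])  # alive: all 3 states present

x_ = x[-1].flatten(1)
mask = ((x_ == -1).any(dim=1)
        .logical_and((x_ == 0).any(dim=1))
        .logical_and((x_ == 1).any(dim=1)))

print(mask)              # tensor([ True, False,  True])
print(x[:, mask].shape)  # torch.Size([2, 2, 2, 2])
```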