# Optimizing a PyTorch function to eliminate a for loop

## Question:

Lately I have been developing a function that operates on tensors of shape

`torch.Size([51, 265, 23, 23])`

where the first dimension is time, the second is the pattern index, and the last two are the spatial size of each pattern.

Each individual pattern can take at most three states: `-1`, `0`, and `1`. A pattern is considered ‘alive’ when all three states are present, and ‘dead’ in every other case.

My objective is to filter out all the dead patterns by checking the last time step of the tensor (`tensor_sims[-1]`).
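As a concrete illustration of the definition, here are two made-up 2×2 patterns. `is_alive` is a hypothetical helper that applies the same `unique().numel() == 3` test as the question's implementation, which matches the definition only under the stated assumption that cells never take values outside `-1`, `0`, `1`.

```python
import torch

# Two hypothetical 2x2 patterns at the last time step (illustrative values only).
alive = torch.tensor([[-1, 0], [1, 1]])  # contains -1, 0 and 1 -> alive
dead = torch.tensor([[0, 1], [1, 0]])    # -1 never appears   -> dead

def is_alive(pattern: torch.Tensor) -> bool:
    # Assumes values are restricted to {-1, 0, 1}, so 3 distinct values means all three states.
    return pattern.unique().numel() == 3

print(is_alive(alive), is_alive(dead))  # True False
```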

### My current implementation (which works) is:

```python
import torch

def filter_patterns(tensor_sims):
    # Get the indices of the patterns (dim 1) that need to be kept
    keep_indices = torch.tensor(
        [i for i in range(tensor_sims.shape[1])
         if tensor_sims[-1, i].unique().numel() == 3],
        dtype=torch.long,  # keeps indexing valid even if no pattern survives
    )
    # Keep only the patterns that meet the condition
    tensor_sims = tensor_sims[:, keep_indices]
    print(f'Number of patterns: {tensor_sims.shape[1]}')
    return tensor_sims
```

Unfortunately I’m not able to get rid of the for loop.

I tried playing around with `torch.unique()` and its `dim` parameter, and I tried reducing the dimensions of the tensor and flattening it, but nothing worked.
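For context on why `torch.unique` does not help directly, here is a small sketch on a made-up 3-pattern tensor (standing in for the last time step). With `dim=0` it deduplicates whole patterns, and without `dim` it returns the distinct values of the entire tensor; neither counts the states inside each pattern.

```python
import torch

# Made-up stand-in for the last time step: 3 patterns of 2x2 values.
last = torch.tensor([
    [[-1, 0], [1, 1]],   # pattern 0: all three states
    [[0, 1], [1, 0]],    # pattern 1: no -1
    [[-1, 0], [1, 1]],   # pattern 2: duplicate of pattern 0
])

# dim=0 removes duplicate patterns (whole slices along dim 0).
uniq_patterns = torch.unique(last, dim=0)
print(uniq_patterns.shape)  # torch.Size([2, 2, 2])

# No dim: values from all patterns are pooled together.
print(torch.unique(last).tolist())  # [-1, 0, 1]
```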

### Found solution (thanks to the answer):

```python
import torch

def filter_patterns(tensor_sims):
    # Flatten the spatial dimensions of the last time step: (patterns, H*W)
    x_ = tensor_sims[-1].flatten(1)
    # One mask per state: does the pattern contain -1, 0, and 1 at least once?
    mask_minus_one = (x_ == -1).any(dim=1)
    mask_zero = (x_ == 0).any(dim=1)
    mask_one = (x_ == 1).any(dim=1)
    # A pattern is kept only if all three states are present
    mask = mask_minus_one.logical_and(mask_zero).logical_and(mask_one)
    # Keep only the patterns that meet the condition
    tensor_sims = tensor_sims[:, mask]
    print(f'Number of patterns: {tensor_sims.shape[1]}')
    return tensor_sims
```

The new implementation is much faster.
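To check that both versions select the same patterns, and to get a rough idea of the speedup on your own machine, a sketch along these lines may help. The helper names `filter_loop`, `filter_masked`, and `best_time` are hypothetical, the print calls are dropped, and the input is random data with the question's shape in which every third pattern has its `-1` values clamped to `0` so that some patterns are dead. Timings are CPU wall-clock only; on a GPU you would need to synchronize before reading the clock.

```python
import time
import torch

def filter_loop(t: torch.Tensor) -> torch.Tensor:
    # Per-pattern loop from the question (dtype=torch.long keeps an empty selection valid).
    keep = torch.tensor(
        [i for i in range(t.shape[1]) if t[-1, i].unique().numel() == 3],
        dtype=torch.long,
    )
    return t[:, keep]

def filter_masked(t: torch.Tensor) -> torch.Tensor:
    # Vectorized version from the question: one any() per state, AND-ed together.
    x_ = t[-1].flatten(1)
    mask = (x_ == -1).any(dim=1) & (x_ == 0).any(dim=1) & (x_ == 1).any(dim=1)
    return t[:, mask]

def best_time(fn, t: torch.Tensor, repeats: int = 5) -> float:
    # Best wall-clock time over a few runs; a rough CPU-only measurement.
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(t)
        times.append(time.perf_counter() - start)
    return min(times)

torch.manual_seed(0)
# Random stand-in with the question's shape; real data may be far less uniform.
sims = torch.randint(-1, 2, (51, 265, 23, 23), dtype=torch.int8)
# Make every third pattern dead at the last step by removing its -1 values.
sims[-1, ::3].clamp_(min=0)

a, b = filter_loop(sims), filter_masked(sims)
same = torch.equal(a, b)
print(a.shape[1], same)
print(f"loop: {best_time(filter_loop, sims):.5f}s  masked: {best_time(filter_masked, sims):.5f}s")
```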

## Answers:

I don’t believe you can get away with `torch.unique`, because it won’t work per column. Instead of iterating over `dim=1`, you could construct three mask tensors to check for `-1`, `0`, and `1` values, respectively. To compute the resulting column mask, you only need some basic logic when combining the masks.
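The three masks can also be built with a single broadcast comparison. The sketch below uses a hypothetical helper `alive_mask` with a `states` argument (both assumptions, not from the answer): it compares every cell against every required state along a new last dimension, marks a state as present if any cell matches it, and keeps a pattern only if every state is present. It materializes a boolean tensor of shape (patterns, H*W, number of states), which is small for the sizes in the question but grows with the number of states.

```python
import torch

def alive_mask(last_step: torch.Tensor, states=(-1, 0, 1)) -> torch.Tensor:
    # last_step: (patterns, H, W). Flatten the spatial dims -> (patterns, H*W).
    x_ = last_step.flatten(1)
    required = torch.tensor(states, dtype=x_.dtype, device=x_.device)
    # Broadcast-compare every cell with every required state: (patterns, H*W, S).
    hits = x_.unsqueeze(-1) == required
    # State present if any cell matches; pattern alive if all states are present.
    return hits.any(dim=1).all(dim=-1)

# Two made-up 2x2 patterns: the first has all three states, the second lacks -1.
last = torch.tensor([[[-1, 0], [1, 1]],
                     [[0, 1], [1, 0]]])
print(alive_mask(last).tolist())  # [True, False]
```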

Since you only check the last time step, focus on that and flatten the spatial dimensions:

```
x_ = x[-1].flatten(1)
```
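With the shape from the question, this flattening step turns each 23×23 grid into a row of 529 cells, one row per pattern. A sketch using a zero-filled placeholder tensor (the values are irrelevant here; `int8` just keeps it small):

```python
import torch

# Placeholder tensor with the shape from the question.
x = torch.zeros(51, 265, 23, 23, dtype=torch.int8)
last = x[-1]          # (265, 23, 23): one 23x23 grid per pattern
x_ = last.flatten(1)  # (265, 529): each row holds one pattern's cells
print(last.shape, x_.shape)  # torch.Size([265, 23, 23]) torch.Size([265, 529])
```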

The three masks to identify the `-1`, `0`, and `1` conditions can be obtained with `x_ == -1`, `x_ == 0`, and `x_ == 1`, respectively. Combine them with `torch.logical_or`:

```
mask = (x_ == -1).logical_or(x_ == 0).logical_or(x_ == 1)
```

Finally, check that all elements are `True` across rows:

```
keep_indices = mask.all(dim=1)
```
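One caveat worth checking: the `logical_or` plus `all` combination keeps a pattern when every cell is one of `-1`, `0`, `1`. That is a membership test, not the question's requirement that all three states appear, so if cells can only take those values it would keep every pattern. The question's final version tests presence per state with `any(dim=1)` and `logical_and` instead. A small comparison on two made-up 2×2 patterns:

```python
import torch

# Last time step for two toy 2x2 patterns; pattern 1 never contains -1.
last = torch.tensor([[[-1, 0], [1, 1]],
                     [[0, 1], [1, 0]]])
x_ = last.flatten(1)

# Membership check: every cell is one of -1, 0, 1.
membership = ((x_ == -1) | (x_ == 0) | (x_ == 1)).all(dim=1)
# Presence check: each of -1, 0, 1 occurs at least once.
presence = (x_ == -1).any(dim=1) & (x_ == 0).any(dim=1) & (x_ == 1).any(dim=1)

print(membership.tolist())  # [True, True]
print(presence.tolist())    # [True, False]
```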