# Efficiently check whether each column in a PyTorch tensor has a corresponding reversed column

## Question:

I have a collection of tensors of common shape `(2,ncol)`. Example:

```python
torch.tensor([[1, 2, 3, 7, 8], [3, 3, 1, 8, 7]], dtype=torch.long)
```

For each tensor, I want to determine if, for each column `[[a], [b]]`, the reversed column `[[b], [a]]` is also in the tensor. For example, in this case, since `ncol` is odd, I can immediately say that this is not the case. But in this other example

```python
torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]], dtype=torch.long)
```

I would actually have to perform the check. A naive solution would be:

```python
import torch

test = torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]], dtype=torch.long)

def are_column_paired(matrix: torch.Tensor) -> bool:
    ncol = matrix.shape[1]  # number of columns
    # An odd number of columns can never be split into pairs.
    if ncol % 2 != 0:
        return False

    column_has_match = torch.zeros(ncol, dtype=torch.bool)
    for i in range(ncol):
        if column_has_match[i]:
            continue
        column = matrix[:, i]
        j = i + 1
        # Scan the later, still unmatched columns for the reverse of column i.
        while not column_has_match[i] and j <= ncol - 1:
            if column_has_match[j]:
                j = j + 1
                continue
            current_column = matrix[:, j].flip(dims=[0])  # [b, a] -> [a, b]
            if torch.equal(column, current_column):
                column_has_match[i], column_has_match[j] = True, True
            j = j + 1

    all_paired = torch.all(column_has_match).item()
    return all_paired
```

But of course this is slow, and possibly not Pythonic. How can I write more efficient code?

PS: note that while `test` here is very small, in the actual use case I expect `ncol` to be O(10^5).

## Answer:

Here is one simple approach. It is likely not the most efficient you can get, but it is much faster than your current solution. The idea is to sort the columns of both the original tensor and the row-flipped tensor, then check whether the two results are identical. I believe the time complexity of this approach is `O(n log n)`, as opposed to `O(n^2)` for your solution.

```python
import torch

def are_columns_paired(matrix: torch.Tensor) -> bool:
    flipped_matrix = matrix.flip(dims=[0])  # swap the rows, i.e. reverse every column
    matrix_sorted = matrix[:, matrix[1].argsort()]  # sort by second row
    # sort by first row, keeping the second-row order fixed when there is a tie
    matrix_sorted = matrix_sorted[:, matrix_sorted[0].sort(stable=True).indices]
    flipped_matrix = flipped_matrix[:, flipped_matrix[1].argsort()]
    flipped_matrix = flipped_matrix[:, flipped_matrix[0].sort(stable=True).indices]
    return bool((matrix_sorted == flipped_matrix).all())
```

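Here, for both the original and the flipped matrix, we sort the columns lexicographically: first by the first row and, when there is a tie, by the second row. The two sorted tensors are equal exactly when every column value occurs as many times as its reverse.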
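The two-pass stable sort is a lexicographic sort of the columns. As a swapped-in alternative (not the code above), the same multiset comparison can be sketched with a single sort per tensor by packing each column `(a, b)` into one integer key `a * base + b`. This assumes the entries are non-negative integers small enough that `(max + 1) ** 2` fits in `int64`; `are_columns_paired_keys` is a hypothetical name.

```python
import torch


def are_columns_paired_keys(matrix: torch.Tensor) -> bool:
    """Return True if the multiset of columns equals the multiset of reversed columns.

    Hypothetical helper. Assumes `matrix` has shape (2, ncol) and holds
    non-negative integers with (max + 1) ** 2 < 2 ** 63, so the packed
    keys below fit in int64 without overflow.
    """
    if matrix.shape[1] == 0:
        return True  # no columns, nothing to pair
    a, b = matrix[0].long(), matrix[1].long()
    base = int(matrix.max().item()) + 1  # every entry is strictly below base
    # Pack column (a, b) into one integer; distinct columns get distinct keys.
    keys = a * base + b
    # The reversed column (b, a) is packed the same way.
    flipped_keys = b * base + a
    # Equal multisets of keys <=> equal sorted key sequences.
    return torch.equal(torch.sort(keys).values, torch.sort(flipped_keys).values)
```

With values from 0 to 999999, `base` is 10^6 and every key stays below 10^12, far from the `int64` limit. Sorting one 1-D tensor of keys replaces the two column-gathering passes.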

I tested both approaches on a randomly generated tensor with `ncol=2000000` and values ranging from 0 to 999999. The code above ran in about 1 second, while the approach from the question had still not finished after an hour.
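One caveat: the sorted comparison lets a symmetric column `[[a], [a]]` act as its own reverse, whereas the loop in the question requires every column to have a distinct partner. For example, a single column `[[5], [5]]` passes the sort-based check but is rejected by the loop (odd `ncol`). If that stricter pairing is needed, a minimal sketch of an extra check, assuming integer entries, is to require that every symmetric column appears an even number of times; `symmetric_columns_pair_up` is a hypothetical name.

```python
import torch


def symmetric_columns_pair_up(matrix: torch.Tensor) -> bool:
    """Hypothetical helper: True if every symmetric column [[a], [a]] occurs an
    even number of times, so each copy can be matched with a different copy.

    Assumes `matrix` has shape (2, ncol) with integer entries.
    """
    symmetric = matrix[0] == matrix[1]  # columns that are their own reverse
    values = matrix[0][symmetric]
    if values.numel() == 0:
        return True  # no symmetric columns, nothing extra to check
    # Count how often each distinct symmetric column occurs.
    _, counts = torch.unique(values, return_counts=True)
    return bool((counts % 2 == 0).all())
```

Combining it as `are_columns_paired(matrix) and symmetric_columns_pair_up(matrix)` should then reproduce the condition the loop in the question checks, while staying `O(n log n)`.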
