Efficiently check whether each column in a PyTorch tensor has a corresponding reversed column


I have a collection of tensors of common shape (2,ncol). Example:

torch.tensor([[1, 2, 3, 7, 8], [3, 3, 1, 8, 7]], dtype=torch.long)

For each tensor, I want to determine whether, for every column [[a], [b]], the reversed column [[b], [a]] is also in the tensor. In this first example, since ncol is odd, I can immediately say that this is not the case. But in this other example

torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]], dtype=torch.long)

I would actually have to perform the check. A naive solution would be

import torch

test = torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]], dtype=torch.long)

def are_column_paired(matrix: torch.Tensor) -> bool:
    ncol = matrix.shape[1]
    # an odd number of columns can never be split into reversed pairs
    if ncol % 2 != 0:
        all_paired = False
        return all_paired

    column_has_match = torch.zeros(ncol, dtype=torch.bool)
    for i in range(ncol):
        if column_has_match[i]:
            continue  # column i was already paired with an earlier column
        column = matrix[:, i]
        j = i + 1
        while not column_has_match[i] and j < ncol:
            if column_has_match[j]:
                j = j + 1
                continue  # column j is already taken; try the next one
            # compare column i with column j reversed
            current_column = matrix[:, j].flip(dims=[0])
            if torch.equal(column, current_column):
                column_has_match[i], column_has_match[j] = True, True
            j = j + 1

    all_paired = torch.all(column_has_match).item()

    return all_paired

But of course this is slow and possibly not pythonic. How can I write more efficient code?

PS: note that while test here is very small, in the actual use case I expect ncol to be O(10^5).
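Reversing every column at once amounts to flipping the tensor along its first dimension, so the check can be restated as: every column of the row-flipped tensor must also occur among the columns of the original. A small illustration on the second example tensor (the name reversed_columns is just for illustration):

```python
import torch

t = torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]], dtype=torch.long)

# flip(dims=[0]) swaps the two rows, which turns every column [a, b] into [b, a]
reversed_columns = t.flip(dims=[0])
print(reversed_columns)
# tensor([[3, 3, 1, 8, 7, 2],
#         [1, 2, 3, 7, 8, 4]])
```

Here the reversed column [3, 2] (from [2, 3]) and [2, 4] (from [4, 2]) do not appear in t, so this tensor fails the check.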

Asked By: DeltaIV



Here is one possible simple approach. It is likely not the most efficient you can get, but it is much faster than your current solution. The idea is to sort the columns of both the original and the row-flipped tensor and check whether the two sorted tensors are identical. I believe the time complexity of this approach is O(n log n), as opposed to O(n^2) for yours.

import torch

def are_columns_paired(matrix):
    flipped_matrix = matrix.flip(dims=[0])  # reverse every column: [a, b] -> [b, a]
    matrix_sorted = matrix[:, matrix[1].argsort()]  # order columns by the second row
    matrix_sorted = matrix_sorted[:, matrix_sorted[0].sort(stable=True)[1]]  # then stably by the first row, keeping the second-row order on ties
    flipped_matrix = flipped_matrix[:, flipped_matrix[1].argsort()]
    flipped_matrix = flipped_matrix[:, flipped_matrix[0].sort(stable=True)[1]]
    return (matrix_sorted == flipped_matrix).all()  # 0-d bool tensor; use .item() for a Python bool
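A quick usage check on the two tensors from the question, plus one made-up tensor (paired) in which every column does have its reverse. The function is repeated so the snippet runs on its own, and bool(...) converts the 0-d result tensor to a Python bool:

```python
import torch

def are_columns_paired(matrix):
    flipped_matrix = matrix.flip(dims=[0])
    matrix_sorted = matrix[:, matrix[1].argsort()]
    matrix_sorted = matrix_sorted[:, matrix_sorted[0].sort(stable=True)[1]]
    flipped_matrix = flipped_matrix[:, flipped_matrix[1].argsort()]
    flipped_matrix = flipped_matrix[:, flipped_matrix[0].sort(stable=True)[1]]
    return (matrix_sorted == flipped_matrix).all()

odd = torch.tensor([[1, 2, 3, 7, 8], [3, 3, 1, 8, 7]], dtype=torch.long)
test = torch.tensor([[1, 2, 3, 7, 8, 4], [3, 3, 1, 8, 7, 2]], dtype=torch.long)
# hypothetical example: [1,3]/[3,1], [7,8]/[8,7], [2,4]/[4,2]
paired = torch.tensor([[1, 3, 7, 8, 2, 4], [3, 1, 8, 7, 4, 2]], dtype=torch.long)

print(bool(are_columns_paired(odd)))     # False
print(bool(are_columns_paired(test)))    # False
print(bool(are_columns_paired(paired)))  # True
```

Note that the sort-based check does not need the odd-ncol shortcut: an unpaired column makes the two sorted tensors differ anyway.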

Here, for both the original and the flipped matrix, we sort the columns lexicographically: first by the first row and, when there is a tie, by the second row.
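To see why sorting by the second row first and then stably by the first row yields this lexicographic order, here are the intermediate states for a small hypothetical tensor m:

```python
import torch

m = torch.tensor([[2, 1, 2, 1], [9, 5, 3, 7]])

# pass 1: order the columns by the second row (the secondary key)
by_second = m[:, m[1].argsort()]
# columns are now [2,3], [1,5], [1,7], [2,9]

# pass 2: stable sort by the first row (the primary key); ties keep the pass-1 order
lex = by_second[:, by_second[0].sort(stable=True).indices]
print(lex)
# tensor([[1, 1, 2, 2],
#         [5, 7, 3, 9]])
```

Without stable=True in the second pass, the columns [1,5] and [1,7] (or [2,3] and [2,9]) could come out in either order, and two tensors with the same columns might not compare equal.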

I tested both approaches on a randomly generated tensor with ncol=2000000 and values ranging from 0 to 999999. The above code ran in about 1 second, while the approach from the question had not finished even after an hour.
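One possible variant (not part of the answer and not benchmarked here) packs each column [a, b] into a single int64 key a * base + b, so that one sort per tensor replaces the two lexicographic passes. The function name are_columns_paired_keys is hypothetical; the sketch assumes a non-empty tensor whose value range (max minus min) stays below roughly 3 * 10**9, so that the keys fit in int64:

```python
import torch

def are_columns_paired_keys(matrix: torch.Tensor) -> bool:
    """Hypothetical variant: compare sorted keys of the columns and of the reversed columns."""
    # shift so all entries are non-negative (assumes a non-empty tensor)
    shifted = matrix.long() - matrix.min()
    # base must exceed the largest entry so each key is unique and order-preserving
    base = int(shifted.max().item()) + 1
    a, b = shifted[0], shifted[1]
    keys = a * base + b            # key of each column [a, b]
    reversed_keys = b * base + a   # key of each reversed column [b, a]
    # equal sorted keys <=> the columns and reversed columns form the same multiset
    return torch.equal(keys.sort().values, reversed_keys.sort().values)

example = torch.tensor([[1, 3, 7, 8], [3, 1, 8, 7]], dtype=torch.long)
print(are_columns_paired_keys(example))
```

The result has the same meaning as the sort-based answer: the check passes exactly when the multiset of columns equals the multiset of reversed columns.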

Answered By: GoodDeeds