PyTorch: bitwise OR all elements below a certain dimension

Question:

New to pytorch and tensors in general, I could use some guidance 🙂 I’ll do my best to write a correct question, but I may use terms incorrectly here and there. Feel free to correct all of this 🙂

Say I have a tensor of shape (n, 3, 3). Essentially, n matrices of 3×3. Each of these matrices contains either 0 or 1 for each cell.

What’s the best (fastest, easiest?) way to do a bitwise OR for all of these matrices?

For example, if I have 3 matrices:

0 0 1
0 0 0
1 0 0

--

1 0 0
0 0 0
1 0 1

--

0 1 1
0 1 0
1 0 1

I want the final result to be

1 1 1
0 1 0
1 0 1
Asked By: aspyct


Answers:

A straightforward way to do this in PyTorch is torch.bitwise_or(), which takes two integer or boolean tensors and ORs them element-wise. To combine all the matrices in your tensor, loop over the first dimension and OR each matrix into a running result:

import torch

# Create a tensor of shape (n, 3, 3)
n = 10
tensor = torch.randint(0, 2, (n, 3, 3))

# Start from an all-zero 3x3 matrix with the same dtype as the input
result = torch.zeros_like(tensor[0])

# Iterate through the first dimension of the tensor
for i in range(n):
    # OR the current matrix into the running result
    result = torch.bitwise_or(result, tensor[i])

# Print the result
print(result)
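The loop above runs once per matrix in Python. A minimal sketch of a loop-free alternative (not part of the original answer), assuming the tensor holds only 0s and 1s: reduce along dimension 0 with torch.any, which computes a logical OR across that dimension, then convert the boolean result back to the input dtype. For 0/1 data, logical OR and bitwise OR give the same cells.

```python
import torch

# Example input: n random 3x3 matrices of 0s and 1s
n = 10
tensor = torch.randint(0, 2, (n, 3, 3))

# Loop version, as in the answer above
loop_result = torch.zeros_like(tensor[0])
for i in range(n):
    loop_result = torch.bitwise_or(loop_result, tensor[i])

# Vectorized version: logical OR across dim 0, then back to the input dtype
any_result = torch.any(tensor.bool(), dim=0).to(tensor.dtype)

print(torch.equal(loop_result, any_result))  # True
```

The .bool() call is there because some older PyTorch versions only accept bool or uint8 inputs to torch.any.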

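Alternatively, you can fold the matrices together with reduce() from Python's functools module (PyTorch itself has no torch.reduce function). This is more concise, but it still loops in Python, so it is not faster than the for loop:

import functools
result = functools.reduce(torch.bitwise_or, tensor)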
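As a concrete check, here is a minimal sketch that ORs the three matrices from the question together with functools.reduce from the Python standard library. Iterating over a tensor yields its slices along dimension 0, so reduce applies torch.bitwise_or to them pairwise.

```python
import functools
import torch

# The three 3x3 matrices from the question, stacked into shape (3, 3, 3)
tensor = torch.tensor([[[0, 0, 1], [0, 0, 0], [1, 0, 0]],
                       [[1, 0, 0], [0, 0, 0], [1, 0, 1]],
                       [[0, 1, 1], [0, 1, 0], [1, 0, 1]]])

# ((m0 | m1) | m2), one 3x3 slice at a time
result = functools.reduce(torch.bitwise_or, tensor)
print(result)
# tensor([[1, 1, 1],
#         [0, 1, 0],
#         [1, 0, 1]])
```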

Both methods above give a single (3, 3) matrix: the bitwise OR of all the submatrices in the original tensor.

Answered By: Hassan

Sum the matrices along the first dimension and check where the sum is above 0. Since every cell is 0 or 1, a cell of the OR is 1 exactly when its sum is positive:

import torch

tensor = torch.tensor([[[0, 0, 1],
                       [0, 0, 0],
                       [1, 0, 0]],
                      [[1, 0, 0],
                       [0, 0, 0],
                       [1, 0, 1]],
                      [[0, 1, 1],
                       [0, 1, 0],
                       [1, 0, 1]]])

# Sum along dim 0, then mark every cell where at least one matrix has a 1
tensor2 = torch.sum(tensor, dim=0) > 0
tensor2 = tensor2.to(torch.uint8)
print(tensor2)
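The sum-and-threshold trick computes a logical OR, which matches bitwise OR only because every cell is 0 or 1. A minimal sketch (not from the original answers) comparing it with functools.reduce over torch.bitwise_or and with torch.amax, whose maximum over 0/1 values also acts as an OR. The second part uses small hypothetical integer values to show where the threshold and a true bitwise OR differ.

```python
import functools
import torch

# 0/1 data: bitwise OR, sum > 0 and amax all agree
bits = torch.randint(0, 2, (5, 3, 3))
or_result = functools.reduce(torch.bitwise_or, bits)
sum_result = (bits.sum(dim=0) > 0).to(bits.dtype)
max_result = bits.amax(dim=0)  # maximum of 0/1 values acts as OR
print(torch.equal(or_result, sum_result), torch.equal(or_result, max_result))  # True True

# General integers: bitwise OR keeps individual bits, the threshold does not
ints = torch.tensor([[1], [2]])
print(functools.reduce(torch.bitwise_or, ints))  # tensor([3])
print((ints.sum(dim=0) > 0).to(ints.dtype))      # tensor([1])
```

If the matrices really only contain 0 and 1, any of these works; for arbitrary integers, only the bitwise OR versions compute a bitwise OR.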
Answered By: Michael Cao