Finding non-intersection of two pytorch tensors

Question:

Thanks everyone in advance for your help! I’m trying to do something like NumPy’s setdiff1d in PyTorch. For example, given the two tensors below:

t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')

The expected output should be (sorted or unsorted):

torch.tensor([9, 12, 5])

Ideally, the operations would run entirely on the GPU, with no transfers back and forth between GPU and CPU. Much appreciated!

Asked By: Shiki.E


Answers:

If you don’t want to leave CUDA, a workaround could be:

import torch

t1 = torch.tensor([1, 9, 12, 5, 24], device='cuda')
t2 = torch.tensor([1, 24], device='cuda')
# Boolean mask over t1: True where the element differs from every element of t2
mask = torch.ones_like(t1, dtype=torch.bool)
for elem in t2:
    mask &= (t1 != elem)
difference = t1[mask]
Answered By: ntipakos

If you don’t want a for loop, this compares all values in one go.

You can also get the non-intersection easily:

import torch

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])

# Repeat t2 into a (len(t2), len(t1)) grid so all pairs are compared at once
compareview = t2.repeat(t1.shape[0], 1).T

# Intersection: elements of t1 equal to at least one element of t2
print(t1[(compareview == t1).T.any(1)])
# Non-intersection: elements of t1 different from every element of t2
print(t1[(compareview != t1).T.all(1)])
# Output:
# tensor([ 1, 24])
# tensor([ 9, 12,  5])
Answered By: Harry

I came across the same problem, but the proposed solutions were far too slow for larger arrays. The following simple solution works on both CPU and GPU and is significantly faster than the other proposed solutions:

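# Assumes t1 and t2 each contain no duplicate values
combined = torch.cat((t1, t2))
uniques, counts = combined.unique(return_counts=True)
# Values present in exactly one of the tensors (symmetric difference);
# this equals t1 minus t2 only when t2 is a subset of t1, as in the question
difference = uniques[counts == 1]
# Values present in both tensors
intersection = uniques[counts > 1]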
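Because `counts == 1` keeps exclusive values from either tensor, a one-sided difference like setdiff1d needs a small twist. A minimal sketch, assuming either input may contain duplicates: deduplicate both tensors and concatenate t2 twice, so every value from t2 reaches a count of at least 2 and only values exclusive to t1 keep a count of 1. The function name `setdiff1d_counts` is hypothetical; like NumPy's setdiff1d, it returns sorted unique values.

```python
import torch

def setdiff1d_counts(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    # Deduplicate so each value contributes at most once per input
    t1_u = torch.unique(t1)
    t2_u = torch.unique(t2)
    # Values only in t1 appear once; any value from t2 appears at least twice
    combined = torch.cat((t1_u, t2_u, t2_u))
    uniques, counts = combined.unique(return_counts=True)
    return uniques[counts == 1]

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24, 3])
print(setdiff1d_counts(t1, t2))
```

Everything stays in torch ops, so the same code runs on CUDA tensors without a round trip to the CPU.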
Answered By: Olivier

For intersection I do:

import torch

first = torch.tensor([1, 2, 3, 4, 5, 6])
second = torch.tensor([7, 3, 9, 1])
# Broadcasting compares every pair: shape (len(second), len(first))
intersection = first[(first.view(1, -1) == second.view(-1, 1)).any(dim=0)]

Then for the diff I would do:

diff = first[(first.view(1, -1) != second.view(-1, 1)).all(dim=0)]
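PyTorch 1.10 and later also provide `torch.isin`, which performs the same membership test without building the pair matrix by hand; passing `invert=True` returns the mask for elements that are absent from the second tensor. A minimal sketch on the question's tensors, shown on CPU (the same calls accept CUDA tensors):

```python
import torch

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])

# Boolean mask: True where an element of t1 also occurs in t2
intersection = t1[torch.isin(t1, t2)]

# invert=True flips the test: True where an element of t1 is absent from t2
difference = t1[torch.isin(t1, t2, invert=True)]

print(intersection)
print(difference)
```

Unlike NumPy's setdiff1d, this keeps t1's original order and any duplicates it contains.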
Answered By: user2648582

Here’s a function similar to NumPy’s setdiff1d:

import torch

def set_diff_1d(t1, t2, assume_unique=False):
    """
    Set difference of two 1D tensors.
    Returns the unique values in t1 that are not in t2.

    """
    if not assume_unique:
        t1 = torch.unique(t1)
        t2 = torch.unique(t2)
    return t1[(t1[:, None] != t2).all(dim=1)]
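The `t1[:, None] != t2` comparison materializes a len(t1) x len(t2) matrix. A sketch of an alternative that swaps in binary search via `torch.searchsorted`, so memory stays linear in the input sizes: `torch.unique` returns t2 sorted, and a t1 element belongs to t2 exactly when the value at its insertion point equals it. The name `set_diff_1d_sorted` is hypothetical, and unlike the function above it keeps t1's order and duplicates.

```python
import torch

def set_diff_1d_sorted(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """Elements of 1D tensor t1 that do not occur in 1D tensor t2 (order kept)."""
    # torch.unique sorts ascending by default, which searchsorted requires
    t2_sorted = torch.unique(t2)
    if t2_sorted.numel() == 0:
        return t1.clone()
    # Leftmost insertion index of each t1 element within t2_sorted
    idx = torch.searchsorted(t2_sorted, t1)
    # Elements larger than every t2 value get index len(t2_sorted); clamp them
    idx = idx.clamp(max=t2_sorted.numel() - 1)
    # A match at the insertion point means the element is present in t2
    found = t2_sorted[idx] == t1
    return t1[~found]

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24, 3])
print(set_diff_1d_sorted(t1, t2))
```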
Answered By: Andreas K.

TL;DR: just take torch_intersect1d from the code snippet below if your tensors have numel() larger than ~1e4.

Sometimes a dense pair matrix of size num_t1 * num_t2 is too large to allocate. Also, when t1 or t2 is huge, logical operations on dense 2D pair matrices are slow.

Building on @Olivier’s torch.unique answer and extending it so that one can also get set(t2) - set(t1) or set(t1) - set(t2), here is a solution that only requires O(num_t1 + num_t2) GPU memory:

import torch
from torch.utils.benchmark import Timer
device = torch.device('cuda')

# t1 = torch.tensor([1, 9, 12, 5, 24], device=device)
# t2 = torch.tensor([1, 24, 3], device=device)

# t1 = torch.unique(torch.randint(4096, [4096], device=device))
# t2 = torch.unique(torch.randint(4096, [8192], device=device))

t1 = torch.unique(torch.randint(40960, [40960], device=device))
t2 = torch.unique(torch.randint(40960, [81920], device=device))

def torch_intersect1d(t1: torch.Tensor, t2: torch.Tensor):
    # NOTE: requires t1 and t2 to be unique 1D tensors in advance.
    # Method: based on unique's count
    num_t1, num_t2 = t1.numel(), t2.numel()
    u, inv, cnt = torch.unique(torch.cat([t1,t2]), return_counts=True, return_inverse=True)
    
    cnt_12 = cnt[inv]
    cnt_t1, cnt_t2 = cnt_12[:num_t1], cnt_12[num_t1:]
    m_t1 = (cnt_t1 == 2)
    inds_t1 = m_t1.nonzero()[..., 0]
    inds_t1_exclusive = (~m_t1).nonzero()[..., 0]
    inds_t2_exclusive = (cnt_t2 == 1).nonzero()[..., 0]

    intersection = t1[inds_t1]
    t1_exclusive = t1[inds_t1_exclusive]
    t2_exclusive = t2[inds_t2_exclusive]
    return intersection, t1_exclusive, t2_exclusive

def torch_intersect1d_dense_pair(t1: torch.Tensor, t2: torch.Tensor):
    # NOTE: requires t1 and t2 to be unique 1D tensors in advance.
    # Method: expands to dense 2D pair matrix
    match = (t1.view(1,-1) == t2.view(-1,1))
    m_t1, m_t2 = match.any(dim=0), match.any(dim=1)
    inds_t1 = m_t1.nonzero()[..., 0]
    inds_t1_exclusive = (~m_t1).nonzero()[..., 0]
    inds_t2_exclusive = (~m_t2).nonzero()[..., 0]
    
    intersection = t1[inds_t1]
    t1_exclusive = t1[inds_t1_exclusive]
    t2_exclusive = t2[inds_t2_exclusive]
    return intersection, t1_exclusive, t2_exclusive

# Cross validate correctness
i1, t11, t21 = torch_intersect1d(t1, t2)
i2, t12, t22 = torch_intersect1d_dense_pair(t1, t2)
print(torch.equal(i1, i2))
print(torch.equal(t11, t12))
print(torch.equal(t21, t22))

print(Timer(
    stmt="torch_intersect1d(t1, t2)", 
    globals={'torch_intersect1d':torch_intersect1d, 't1': t1, 't2': t2}
).blocked_autorange())

print(Timer(
    stmt="torch_intersect1d_dense_pair(t1, t2)", 
    globals={'torch_intersect1d_dense_pair':torch_intersect1d_dense_pair, 't1': t1, 't2': t2}
).blocked_autorange())

The result for the simple use case
t1 = torch.tensor([1, 9, 12, 5, 24]), t2 = torch.tensor([1, 24, 3]):

intersection: tensor([1, 24])
t1_exclusive: tensor([9, 12, 5])
t2_exclusive: tensor([3])

Benchmark results for torch_intersect1d vs. torch_intersect1d_dense_pair:
When t1 and t2 have 3k~5k elements: 211 us vs. 307 us.
When t1 and t2 have 30k~50k elements: 344 us vs. 6.34 ms.

Detailed benchmark (elapsed time in us):

+------------------------------+--------+--------+--------+---------+----------+--------+---------+
|         tensor size          |  1000  |  3162  | 10000  |  31622  |  100000  | 316227 | 1000000 |
+------------------------------+--------+--------+--------+---------+----------+--------+---------+
|      torch_intersect1d       | 209.44 | 290.57 | 292.31 |  310.83 |  322.11  | 488.71 | 1018.41 |
| torch_intersect1d_dense_pair | 138.30 | 186.37 | 549.01 | 4382.19 | 43741.66 |  OOM   |   OOM   |
+------------------------------+--------+--------+--------+---------+----------+--------+---------+
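The OOM entries line up with a back-of-the-envelope estimate for the dense approach: the `==` comparison alone materializes a num_t2 x num_t1 bool tensor, at one byte per element. A small sketch of that estimate, assuming both inputs have roughly the size in the table header (after `torch.unique` they can be somewhat smaller); `dense_pair_bytes` is a hypothetical helper.

```python
def dense_pair_bytes(num_t1: int, num_t2: int, bytes_per_elem: int = 1) -> int:
    # One torch.bool entry (1 byte) per (t1, t2) pair in the comparison matrix
    return num_t1 * num_t2 * bytes_per_elem

for n in (10_000, 100_000, 316_227):
    print(n, f"{dense_pair_bytes(n, n) / 2**30:.1f} GiB")
```

At 100000 elements the matrix is already about 9.3 GiB, and at 316227 it exceeds 90 GiB, whereas the unique-based version only allocates tensors linear in num_t1 + num_t2.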
Answered By: Jianfei Guo