Finding non-intersection of two pytorch tensors
Question:
Thanks everyone in advance for your help! What I’m trying to do in PyTorch is something like numpy’s `setdiff1d`. For example, given the two tensors below:
# Example inputs, created directly on the first CUDA device.
t1 = torch.tensor([1, 9, 12, 5, 24], device='cuda:0')
t2 = torch.tensor([1, 24], device='cuda:0')
The expected output should be (sorted or unsorted):
torch.tensor([9, 12, 5])
Ideally the operations are done on GPU and no back and forth between GPU and CPU. Much appreciated!
Answers:
if you don’t want to leave cuda, a workaround could be:
# Stay on the GPU when one is present; fall back to CPU so the snippet still runs elsewhere.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
t1 = torch.tensor([1, 9, 12, 5, 24], device=device)
t2 = torch.tensor([1, 24], device=device)
# Boolean keep-mask over t1. torch.bool is used because indexing with a
# uint8 mask is deprecated; ones_like already places the mask on t1's device.
indices = torch.ones_like(t1, dtype=torch.bool)
for elem in t2:
    # Clear the mask wherever t1 equals the current element of t2.
    indices &= (t1 != elem)
# Elements of t1 that do NOT occur in t2, in t1's original order.
difference = t1[indices]
# Name used by the original snippet; despite the name it holds the set difference.
intersection = difference
If you don’t want a for loop, you can compare all values in one go.
You can also get the non-intersection (the elements of t1 that are not in t2) just as easily:
t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])
# Create a (len(t2), len(t1)) view of t2 to compare all values at once.
# expand() broadcasts without copying, unlike repeat().
compareview = t2.unsqueeze(1).expand(-1, t1.shape[0])
# matches[i, j] is True when t2[i] == t1[j].
matches = compareview == t1
# Intersection: keep t1[j] if it equals ANY element of t2. Using any()
# instead of "sum == 1" stays correct when t2 contains duplicates.
common = t1[matches.any(dim=0)]
print(common)
# Non Intersection: keep t1[j] if it equals NO element of t2.
only_in_t1 = t1[~matches.any(dim=0)]
print(only_in_t1)
tensor([ 1, 24])
tensor([ 9, 12, 5])
I came across the same problem but the proposed solutions were far too slow when using larger arrays. The following simple solution works on CPU and GPU and is significantly faster than the other proposed solutions:
# De-duplicate each input first so that counts reflect membership only,
# not how often a value repeats inside a single tensor.
t1_unique = torch.unique(t1)
t2_unique = torch.unique(t2)
# t2's values are concatenated twice, so after counting:
#   1 -> value occurs only in t1
#   2 -> value occurs only in t2
#   3 -> value occurs in both
combined = torch.cat((t1_unique, t2_unique, t2_unique))
uniques, counts = combined.unique(return_counts=True)
# Values of t1 that are not in t2 (a plain "count == 1" on cat((t1, t2))
# would also include values that occur only in t2).
difference = uniques[counts == 1]
# Values present in both tensors.
intersection = uniques[counts == 3]
For intersection I do:
import torch
first = torch.Tensor([1, 2, 3, 4, 5, 6])
second = torch.Tensor([7, 3, 9, 1])
# pairwise[i, j] is True when second[i] == first[j].
pairwise = second.unsqueeze(1) == first.unsqueeze(0)
# Keep every element of `first` that matches at least one element of `second`.
intersection = first[pairwise.any(dim=0)]
Then for the diff I would do:
# Keep every element of `first` that matches no element of `second`
# ("not equal to all" rewritten as "not (equal to any)").
diff = first[~(second.unsqueeze(1) == first.unsqueeze(0)).any(dim=0)]
Here’s a function that is similar to the numpy’s setdiff1d:
def set_diff_1d(t1, t2, assume_unique=False):
    """Return the values of 1-D tensor ``t1`` that do not occur in ``t2``.

    Modelled on ``numpy.setdiff1d``.

    Args:
        t1: 1-D tensor to filter.
        t2: 1-D tensor of values to exclude.
        assume_unique: If False (default), both inputs are passed through
            ``torch.unique`` first, so the result is sorted and free of
            duplicates. If True, the inputs are used as given and the result
            keeps ``t1``'s order (and any duplicates ``t1`` contains).

    Returns:
        A 1-D tensor holding the elements of ``t1`` that are not in ``t2``.
    """
    if not assume_unique:
        t1 = torch.unique(t1)
        t2 = torch.unique(t2)
    # torch.isin performs the membership test without the caller having to
    # materialise a len(t1) x len(t2) comparison matrix via broadcasting.
    return t1[torch.isin(t1, t2, invert=True)]
TL;DR: just take `torch_intersect1d` from the code snippet below if your tensors have `numel()` larger than ~1e+4.
Sometimes a dense pair matrix of size `num_t1 * num_t2` is too large to allocate. Also, when `t1` or `t2` is huge, logical operations on dense 2D pair matrices are slow.
Based on @Olivier’s answer using `torch.unique`, and extending it so that one can also get results like `set(t2) - set(t1)` or `set(t1) - set(t2)`, here is a solution that only requires GPU memory usage of O(num_t1 + num_t2):
import torch
from torch.utils.benchmark import Timer
# Benchmark on the GPU when one is available; otherwise fall back to the CPU
# so the script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Alternative inputs: uncomment one pair to use it instead of the default below.
# Tiny worked example:
# t1 = torch.tensor([1, 9, 12, 5, 24], device=device)
# t2 = torch.tensor([1, 24, 3], device=device)
# Smaller random benchmark:
# t1 = torch.unique(torch.randint(4096, [4096], device=device))
# t2 = torch.unique(torch.randint(4096, [8192], device=device))
# Default: larger random benchmark. torch.unique makes each input
# duplicate-free, which both intersect functions require.
t1 = torch.unique(torch.randint(40960, [40960], device=device))
t2 = torch.unique(torch.randint(40960, [81920], device=device))
def torch_intersect1d(t1: torch.Tensor, t2: torch.Tensor):
    """Split two 1-D tensors into shared and exclusive elements via unique counts.

    NOTE: both ``t1`` and ``t2`` must already be duplicate-free 1-D tensors;
    the count test below relies on each value occurring at most once per input.

    Returns:
        A tuple ``(intersection, t1_exclusive, t2_exclusive)``; the first two
        keep ``t1``'s element order and the last keeps ``t2``'s.
    """
    merged = torch.cat([t1, t2])
    _, inverse, counts = torch.unique(merged, return_inverse=True, return_counts=True)
    # Occurrence count of every element of the concatenation, in input order.
    occurrences = counts[inverse]
    boundary = t1.numel()
    # With duplicate-free inputs, a count of 2 means "present in both tensors".
    shared_mask = occurrences[:boundary] == 2
    t2_only_mask = occurrences[boundary:] == 1
    return t1[shared_mask], t1[~shared_mask], t2[t2_only_mask]
def torch_intersect1d_dense_pair(t1: torch.Tensor, t2: torch.Tensor):
    """Split two 1-D tensors into shared and exclusive elements via a pairwise matrix.

    NOTE: both ``t1`` and ``t2`` must already be duplicate-free 1-D tensors.

    Returns:
        A tuple ``(intersection, t1_exclusive, t2_exclusive)``; the first two
        keep ``t1``'s element order and the last keeps ``t2``'s.
    """
    t1_row = t1.view(1, -1)
    t2_col = t2.view(-1, 1)
    # equal[i, j] is True when t2[i] == t1[j]; shape is (len(t2), len(t1)).
    equal = t2_col == t1_row
    t1_found = equal.any(dim=0)
    t2_found = equal.any(dim=1)
    return t1[t1_found], t1[~t1_found], t2[~t2_found]
# Cross-validate correctness: both implementations must agree on all three outputs.
i1, t11, t21 = torch_intersect1d(t1, t2)
i2, t12, t22 = torch_intersect1d_dense_pair(t1, t2)
for unique_based, dense_based in ((i1, i2), (t11, t12), (t21, t22)):
    print(torch.equal(unique_based, dense_based))

# Time the unique-count implementation on the same inputs.
unique_timer = Timer(
    stmt="torch_intersect1d(t1, t2)",
    globals={'torch_intersect1d': torch_intersect1d, 't1': t1, 't2': t2},
)
print(unique_timer.blocked_autorange())

# Time the dense pair-matrix implementation on the same inputs.
dense_timer = Timer(
    stmt="torch_intersect1d_dense_pair(t1, t2)",
    globals={'torch_intersect1d_dense_pair': torch_intersect1d_dense_pair, 't1': t1, 't2': t2},
)
print(dense_timer.blocked_autorange())
The result of the simple use case `t1 = torch.tensor([1, 9, 12, 5, 24])`, `t2 = torch.tensor([1, 24, 3])`:
intersection: tensor([1, 24])
t1_exclusive: tensor([9, 12, 5])
t2_exclusive: tensor([3])
The benchmark result of `torch_intersect1d` vs. `torch_intersect1d_dense_pair`:
When `t1`, `t2` are of 3k~5k shape: 211 us vs. 307 us.
When `t1`, `t2` are of 30k~50k shape: 344 us vs. 6.34 ms.
Detailed benchmark (consumed time (us)):
+------------------------------+--------+--------+--------+---------+----------+--------+---------+
| tensor size | 1000 | 3162 | 10000 | 31622 | 100000 | 316227 | 1000000 |
+------------------------------+--------+--------+--------+---------+----------+--------+---------+
| torch_intersect1d | 209.44 | 290.57 | 292.31 | 310.83 | 322.11 | 488.71 | 1018.41 |
| torch_intersect1d_dense_pair | 138.30 | 186.37 | 549.01 | 4382.19 | 43741.66 | OOM | OOM |
+------------------------------+--------+--------+--------+---------+----------+--------+---------+
Thanks everyone in advance for your help! What I’m trying to do in PyTorch is something like numpy’s `setdiff1d`. For example, given the two tensors below:
# Example inputs, created directly on the first CUDA device.
t1 = torch.tensor([1, 9, 12, 5, 24], device='cuda:0')
t2 = torch.tensor([1, 24], device='cuda:0')
The expected output should be (sorted or unsorted):
torch.tensor([9, 12, 5])
Ideally the operations are done on GPU and no back and forth between GPU and CPU. Much appreciated!
if you don’t want to leave cuda, a workaround could be:
# Stay on the GPU when one is present; fall back to CPU so the snippet still runs elsewhere.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
t1 = torch.tensor([1, 9, 12, 5, 24], device=device)
t2 = torch.tensor([1, 24], device=device)
# Boolean keep-mask over t1. torch.bool is used because indexing with a
# uint8 mask is deprecated; ones_like already places the mask on t1's device.
indices = torch.ones_like(t1, dtype=torch.bool)
for elem in t2:
    # Clear the mask wherever t1 equals the current element of t2.
    indices &= (t1 != elem)
# Elements of t1 that do NOT occur in t2, in t1's original order.
difference = t1[indices]
# Name used by the original snippet; despite the name it holds the set difference.
intersection = difference
If you don’t want a for loop, you can compare all values in one go.
You can also get the non-intersection (the elements of t1 that are not in t2) just as easily:
t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])
# Create a (len(t2), len(t1)) view of t2 to compare all values at once.
# expand() broadcasts without copying, unlike repeat().
compareview = t2.unsqueeze(1).expand(-1, t1.shape[0])
# matches[i, j] is True when t2[i] == t1[j].
matches = compareview == t1
# Intersection: keep t1[j] if it equals ANY element of t2. Using any()
# instead of "sum == 1" stays correct when t2 contains duplicates.
common = t1[matches.any(dim=0)]
print(common)
# Non Intersection: keep t1[j] if it equals NO element of t2.
only_in_t1 = t1[~matches.any(dim=0)]
print(only_in_t1)
tensor([ 1, 24])
tensor([ 9, 12, 5])
I came across the same problem but the proposed solutions were far too slow when using larger arrays. The following simple solution works on CPU and GPU and is significantly faster than the other proposed solutions:
# De-duplicate each input first so that counts reflect membership only,
# not how often a value repeats inside a single tensor.
t1_unique = torch.unique(t1)
t2_unique = torch.unique(t2)
# t2's values are concatenated twice, so after counting:
#   1 -> value occurs only in t1
#   2 -> value occurs only in t2
#   3 -> value occurs in both
combined = torch.cat((t1_unique, t2_unique, t2_unique))
uniques, counts = combined.unique(return_counts=True)
# Values of t1 that are not in t2 (a plain "count == 1" on cat((t1, t2))
# would also include values that occur only in t2).
difference = uniques[counts == 1]
# Values present in both tensors.
intersection = uniques[counts == 3]
For intersection I do:
import torch
first = torch.Tensor([1, 2, 3, 4, 5, 6])
second = torch.Tensor([7, 3, 9, 1])
# pairwise[i, j] is True when second[i] == first[j].
pairwise = second.unsqueeze(1) == first.unsqueeze(0)
# Keep every element of `first` that matches at least one element of `second`.
intersection = first[pairwise.any(dim=0)]
Then for the diff I would do:
# Keep every element of `first` that matches no element of `second`
# ("not equal to all" rewritten as "not (equal to any)").
diff = first[~(second.unsqueeze(1) == first.unsqueeze(0)).any(dim=0)]
Here’s a function that is similar to the numpy’s setdiff1d:
def set_diff_1d(t1, t2, assume_unique=False):
    """Return the values of 1-D tensor ``t1`` that do not occur in ``t2``.

    Modelled on ``numpy.setdiff1d``.

    Args:
        t1: 1-D tensor to filter.
        t2: 1-D tensor of values to exclude.
        assume_unique: If False (default), both inputs are passed through
            ``torch.unique`` first, so the result is sorted and free of
            duplicates. If True, the inputs are used as given and the result
            keeps ``t1``'s order (and any duplicates ``t1`` contains).

    Returns:
        A 1-D tensor holding the elements of ``t1`` that are not in ``t2``.
    """
    if not assume_unique:
        t1 = torch.unique(t1)
        t2 = torch.unique(t2)
    # torch.isin performs the membership test without the caller having to
    # materialise a len(t1) x len(t2) comparison matrix via broadcasting.
    return t1[torch.isin(t1, t2, invert=True)]
TL;DR: just take `torch_intersect1d` from the code snippet below if your tensors have `numel()` larger than ~1e+4.
Sometimes a dense pair matrix of size `num_t1 * num_t2` is too large to allocate. Also, when `t1` or `t2` is huge, logical operations on dense 2D pair matrices are slow.
Based on @Olivier’s answer using `torch.unique`, and extending it so that one can also get results like `set(t2) - set(t1)` or `set(t1) - set(t2)`, here is a solution that only requires GPU memory usage of O(num_t1 + num_t2):
import torch
from torch.utils.benchmark import Timer
# Benchmark on the GPU when one is available; otherwise fall back to the CPU
# so the script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Alternative inputs: uncomment one pair to use it instead of the default below.
# Tiny worked example:
# t1 = torch.tensor([1, 9, 12, 5, 24], device=device)
# t2 = torch.tensor([1, 24, 3], device=device)
# Smaller random benchmark:
# t1 = torch.unique(torch.randint(4096, [4096], device=device))
# t2 = torch.unique(torch.randint(4096, [8192], device=device))
# Default: larger random benchmark. torch.unique makes each input
# duplicate-free, which both intersect functions require.
t1 = torch.unique(torch.randint(40960, [40960], device=device))
t2 = torch.unique(torch.randint(40960, [81920], device=device))
def torch_intersect1d(t1: torch.Tensor, t2: torch.Tensor):
    """Split two 1-D tensors into shared and exclusive elements via unique counts.

    NOTE: both ``t1`` and ``t2`` must already be duplicate-free 1-D tensors;
    the count test below relies on each value occurring at most once per input.

    Returns:
        A tuple ``(intersection, t1_exclusive, t2_exclusive)``; the first two
        keep ``t1``'s element order and the last keeps ``t2``'s.
    """
    merged = torch.cat([t1, t2])
    _, inverse, counts = torch.unique(merged, return_inverse=True, return_counts=True)
    # Occurrence count of every element of the concatenation, in input order.
    occurrences = counts[inverse]
    boundary = t1.numel()
    # With duplicate-free inputs, a count of 2 means "present in both tensors".
    shared_mask = occurrences[:boundary] == 2
    t2_only_mask = occurrences[boundary:] == 1
    return t1[shared_mask], t1[~shared_mask], t2[t2_only_mask]
def torch_intersect1d_dense_pair(t1: torch.Tensor, t2: torch.Tensor):
    """Split two 1-D tensors into shared and exclusive elements via a pairwise matrix.

    NOTE: both ``t1`` and ``t2`` must already be duplicate-free 1-D tensors.

    Returns:
        A tuple ``(intersection, t1_exclusive, t2_exclusive)``; the first two
        keep ``t1``'s element order and the last keeps ``t2``'s.
    """
    t1_row = t1.view(1, -1)
    t2_col = t2.view(-1, 1)
    # equal[i, j] is True when t2[i] == t1[j]; shape is (len(t2), len(t1)).
    equal = t2_col == t1_row
    t1_found = equal.any(dim=0)
    t2_found = equal.any(dim=1)
    return t1[t1_found], t1[~t1_found], t2[~t2_found]
# Cross-validate correctness: both implementations must agree on all three outputs.
i1, t11, t21 = torch_intersect1d(t1, t2)
i2, t12, t22 = torch_intersect1d_dense_pair(t1, t2)
for unique_based, dense_based in ((i1, i2), (t11, t12), (t21, t22)):
    print(torch.equal(unique_based, dense_based))

# Time the unique-count implementation on the same inputs.
unique_timer = Timer(
    stmt="torch_intersect1d(t1, t2)",
    globals={'torch_intersect1d': torch_intersect1d, 't1': t1, 't2': t2},
)
print(unique_timer.blocked_autorange())

# Time the dense pair-matrix implementation on the same inputs.
dense_timer = Timer(
    stmt="torch_intersect1d_dense_pair(t1, t2)",
    globals={'torch_intersect1d_dense_pair': torch_intersect1d_dense_pair, 't1': t1, 't2': t2},
)
print(dense_timer.blocked_autorange())
The result of the simple use case `t1 = torch.tensor([1, 9, 12, 5, 24])`, `t2 = torch.tensor([1, 24, 3])`:
intersection: tensor([1, 24])
t1_exclusive: tensor([9, 12, 5])
t2_exclusive: tensor([3])
The benchmark result of `torch_intersect1d` vs. `torch_intersect1d_dense_pair`:
When `t1`, `t2` are of 3k~5k shape: 211 us vs. 307 us.
When `t1`, `t2` are of 30k~50k shape: 344 us vs. 6.34 ms.
Detailed benchmark (consumed time (us)):
+------------------------------+--------+--------+--------+---------+----------+--------+---------+
| tensor size | 1000 | 3162 | 10000 | 31622 | 100000 | 316227 | 1000000 |
+------------------------------+--------+--------+--------+---------+----------+--------+---------+
| torch_intersect1d | 209.44 | 290.57 | 292.31 | 310.83 | 322.11 | 488.71 | 1018.41 |
| torch_intersect1d_dense_pair | 138.30 | 186.37 | 549.01 | 4382.19 | 43741.66 | OOM | OOM |
+------------------------------+--------+--------+--------+---------+----------+--------+---------+