Increase Efficiency of Function That Checks Number of Sphere Intersections

Question:

I have a function that takes in 2 sets of spheres, and the two sets may have different lengths. It finds all the spheres in set B that DO NOT intersect with ANY of the spheres in set A. In the picture below, I want to find the number of red spheres that do not intersect with the blue ones.

[Image: red spheres (set B) and blue spheres (set A)]

The code does the following:

  1. It loops over all pairwise combinations of coordinates from a and b
  2. It computes the distance between each pair
  3. If that distance is less than twice the radius, it appends the coordinates from b to list1 (safeones)
  4. Otherwise, it appends the coordinates from b to list2 (questionableones), if they are not already in list1
  5. It then removes all the points from list2 that are in list1.

Here is the code:

import math
import numpy as np

def check_presence(point, safeones):
    # True if `point` is already in the list `safeones`
    for array in safeones:
        if (np.asarray(array) == point).all():
            return True
    return False

def circle_intersection(a, b, r):  # a (query) returns outliers of the comparison
    safeones = []  # safe points of the comparison
    questionableones = []
    for i in range(len(a)):
        for j in range(len(b)):
            dist = math.sqrt((a[i][0] - b[j][0])**2 + (a[i][1] - b[j][1])**2)
            if dist < 2 * r:
                safeones.append(b[j])
            else:
                if not check_presence(b[j], safeones):
                    questionableones.append(b[j])
    safeones = list(set([tuple(i) for i in safeones]))
    safeones = [list(i) for i in safeones]
    questionableones = list(set([tuple(i) for i in questionableones]))
    questionableones = [list(i) for i in questionableones]
    outliers = [i for i in questionableones if i not in safeones]

    return len(outliers)

The problem is that this function checks ALL POSSIBLE combinations, which is not necessary if I just want the spheres that do not intersect. I am trying to figure out how to reduce the computational cost… something like an early stopping method that lets me exit just the inner for loop as soon as we know that a sphere in set B intersects at least one sphere in set A.
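The early-stopping idea can be expressed with Python's built-in any(), which stops consuming its generator at the first True value. This is a minimal sketch, assuming all spheres share the radius r and centers are given as coordinate tuples; count_non_intersecting is a hypothetical name, and math.dist requires Python 3.8+ (it also works for 3D points).

```python
import math

def count_non_intersecting(a, b, r):
    """Hypothetical helper: count points in b whose sphere (radius r)
    intersects no sphere in a. Assumes every sphere has radius r."""
    count = 0
    for p in b:
        # any() stops at the first sphere in a that intersects p,
        # so the inner loop is exited early for "bad" spheres
        if not any(math.dist(p, q) < 2 * r for q in a):
            count += 1
    return count
```

This only saves work for spheres that do intersect something; in the worst case (few intersections) it is still O(len(a) * len(b)).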

I have tried:

import math
import numpy as np

def check_presence(point, safeones):
    # True if `point` is already in the list `safeones`
    for array in safeones:
        if (np.asarray(array) == point).all():
            return True
    return False

def circle_intersection(a, b, r):  # a (query) returns outliers of the comparison
    questionableones = []
    for i in range(len(b)):
        for j in range(len(a)):
            dist = math.sqrt((b[i][0] - a[j][0])**2 + (b[i][1] - a[j][1])**2)
            print(dist)
            if dist < 2 * r:
                break
            else:
                questionableones.append(b[i])

    questionableones = list(set([tuple(i) for i in questionableones]))
    questionableones = [list(i) for i in questionableones]

    return len(questionableones)

But this outputs a value that seems too high, and I am not sure why. Please help.
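One likely reason the count is too high: in the attempt, the else branch runs once for every sphere in a that does not intersect b[i], so b[i] is appended as soon as any earlier sphere in a is far away, even if a later one intersects it. For example, with a = [(10, 0), (0, 0)], b = [(0, 0)] and r = 1, b[0] is appended after the comparison with a[0] (distance 10), before the break on a[1]. Deduplicating with set() does not undo that. Below is a minimal sketch of a fix using Python's for ... else, where the else clause belongs to the inner for loop and runs only if that loop finishes without break; circle_intersection_fixed is a hypothetical name.

```python
import math

def circle_intersection_fixed(a, b, r):
    """Hypothetical corrected version of the attempt: count points in b
    whose circle (radius r) intersects no circle in a."""
    questionableones = []
    for i in range(len(b)):
        for j in range(len(a)):
            dist = math.sqrt((b[i][0] - a[j][0])**2 + (b[i][1] - a[j][1])**2)
            if dist < 2 * r:
                break  # b[i] intersects a[j]: skip the rest of a
        else:
            # reached only when the inner loop never hit `break`,
            # i.e. b[i] intersects no sphere in a
            questionableones.append(b[i])
    return len(questionableones)
```

Since each b[i] is now appended at most once, the set-based deduplication is no longer needed, unless you also want duplicate points in b to be counted only once.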

Asked By: Chelsea Zou


Answers:

For this type of problem, my tool of choice is the excellent scipy.spatial.KDTree. It avoids the O(n^2) computation time of a full pairwise distance calculation (when the numbers of blue and red spheres are both proportional to n). In fact, it does the whole calculation for 100K blue spheres and 100K red spheres in under 140 ms, whereas a full pairwise approach would need to consider up to 100_000 * 100_000 = 10_000_000_000 distances.

Assuming your sphere centers are in two numpy arrays, blu and red, and all spheres have the same radius:

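import numpy as np
from scipy.spatial import KDTree

def find_too_close(blu, red, radius):
    # build a KD-tree over the blue sphere centers
    tree = KDTree(blu)
    # for each red center, find the nearest blue center, searching only
    # within 2 * radius (two spheres intersect when closer than that)
    dist, idx = tree.query(red, k=1, distance_upper_bound=radius * 2)
    # red spheres whose nearest blue center is closer than 2 * radius
    red_bad_idx = np.nonzero(dist < 2 * radius)[0]
    return red_bad_idx, dist, idx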

That function returns a tuple of:

  • red_bad_idx: the indices of "bad" red spheres (the ones that intersect at least one blue sphere);
  • dist: for each red sphere, the center-to-center distance to the closest blue sphere, or inf if that distance is greater than 2 * radius;
  • idx: the index of the closest blue neighbor (only valid if the corresponding dist is not inf; for missing neighbors, KDTree.query returns len(blu)).
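To get the number the question actually asks for (red spheres that intersect no blue sphere), the returned values can be used directly. Here is a small sketch on hand-made data; the coordinates are made up for illustration, and find_too_close is repeated from above so the snippet runs on its own.

```python
import numpy as np
from scipy.spatial import KDTree

def find_too_close(blu, red, radius):
    # same function as above
    tree = KDTree(blu)
    dist, idx = tree.query(red, k=1, distance_upper_bound=radius * 2)
    red_bad_idx = np.nonzero(dist < 2 * radius)[0]
    return red_bad_idx, dist, idx

# illustrative data: two blue centers, three red centers
blu = np.array([[0.0, 0.0], [1.0, 0.0]])
red = np.array([[0.05, 0.0], [0.5, 0.5], [1.0, 0.1]])
radius = 0.1

red_bad_idx, dist, idx = find_too_close(blu, red, radius)

# number of red spheres that intersect no blue sphere
n_good = len(red) - len(red_bad_idx)

# boolean mask of those "good" red spheres (dist is inf for them,
# and the matching idx entry is len(blu))
good_mask = dist >= 2 * radius
```

Here the middle red center is about 0.71 away from both blue centers, so it is the only "good" one, and its idx entry is len(blu) == 2.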

Reproducible example:

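import matplotlib.pyplot as plt

def gen(n, dim=2):
    # n blue and n red centers, uniformly distributed in the unit (hyper)cube
    blu = np.random.uniform(0, 1, (n, dim))
    red = np.random.uniform(0, 1, (n, dim))
    return blu, red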

np.random.seed(0)
blu, red = gen(20)
radius = .1

red_bad_idx, dist, idx = find_too_close(blu, red, radius)

fig, ax = plt.subplots()
# identify red spheres that intersect at least one blue sphere
# and show the closest blue sphere (black line between centers)
for xy, blu_neighbor_ix in zip(red[red_bad_idx], idx[red_bad_idx]):
    ax.add_patch(plt.Circle(xy, radius, color='r', fill=True, alpha=.1))
    ax.plot(*np.c_[xy, blu[blu_neighbor_ix]], 'k')

# plot the spheres
ax.plot(*blu.T, 'b.')
for xy in blu:
    ax.add_patch(plt.Circle(xy, radius, color='b', fill=False))
ax.plot(*red.T, 'r.')
for xy in red:
    ax.add_patch(plt.Circle(xy, radius, color='r', fill=False))
ax.set_aspect(1)

Speed

blu, red = gen(100_000)
radius = 0.1

%timeit find_too_close(blu, red, radius)
# 131 ms ± 687 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

With the uniform distribution used in gen(), a radius of 0.1 and that many spheres, every red sphere is almost guaranteed to intersect at least one blue sphere. Indeed:

bad, _, _ = find_too_close(blu, red, radius)
>>> len(bad)
100000

Let’s make much smaller spheres and observe how it affects the number of "bad" spheres, as well as the calculation time:

bad, _, _ = find_too_close(blu, red, radius / 1000)
>>> len(bad)
1222
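As a rough sanity check on these two counts, here is a back-of-the-envelope estimate (not part of the original answer), assuming centers uniform in the unit square as in gen() with dim=2 and ignoring edge effects. A red sphere is "bad" when at least one blue center falls inside the disc of radius 2 * radius around its center; expected_bad is a hypothetical helper.

```python
import math

def expected_bad(n_red, n_blu, radius):
    """Approximate expected number of red spheres touching a blue one.
    Assumes uniform centers in the unit square; ignores edge effects."""
    p_hit = math.pi * (2 * radius) ** 2   # area of the disc of radius 2r
    p_miss_all = (1 - p_hit) ** n_blu     # no blue center lands in it
    return n_red * (1 - p_miss_all)

print(expected_bad(100_000, 100_000, 0.1))         # 100000.0
print(expected_bad(100_000, 100_000, 0.1 / 1000))  # about 1249
```

For radius 0.1 the miss probability underflows to zero, matching the 100000 above. For radius / 1000 the estimate of roughly 1249 is in line with the measured 1222; spheres near the edges of the square have less surrounding area, which plausibly explains the slightly lower observed count.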

%timeit find_too_close(blu, red, radius / 1000)
# 85.2 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Additional notes

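  1. If you want to compare with a full pairwise solution (only practical for small inputs, since the intermediate array has len(blu) * len(red) * dim elements):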

    red_bad_idx = np.unique(np.nonzero(np.linalg.norm(
        red - blu[:, None], axis=-1) < 2*radius)[1])
    
  2. With no code change, one can explore higher dimensions. Here is the 3D case:

    np.random.seed(0)
    blu, red = gen(20, 3)
    radius = 0.1
    
    red_bad_idx, dist, idx = find_too_close(blu, red, radius)
    >>> red_bad_idx
    array([ 0,  7,  8, 11, 13, 15, 16, 18, 19])
    
    # compare to finding indices via pairwise distances
    rbi = np.unique(np.nonzero(np.linalg.norm(
        red - blu[:, None], axis=-1) < 2*radius)[1])
    
    >>> np.array_equal(red_bad_idx, rbi)
    True
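The pairwise one-liner from note 1 materializes an array with len(blu) * len(red) * dim elements, which is far too large for 100K spheres of each color. Below is a minimal sketch of a memory-bounded variant that processes red in chunks; find_too_close_pairwise and its chunk parameter are hypothetical names. Peak memory then scales with len(blu) * chunk * dim, but the time is still O(n^2), so it is only useful as a cross-check against the KD-tree result.

```python
import numpy as np

def find_too_close_pairwise(blu, red, radius, chunk=256):
    """Hypothetical brute-force check: indices of red spheres that
    intersect at least one blue sphere, processing red in chunks."""
    bad = []
    for start in range(0, len(red), chunk):
        block = red[start:start + chunk]
        # (len(blu), len(block)) matrix of center-to-center distances
        d = np.linalg.norm(block - blu[:, None], axis=-1)
        # a red sphere is bad if any blue center is closer than 2 * radius
        hit = (d < 2 * radius).any(axis=0)
        bad.append(start + np.nonzero(hit)[0])
    return np.concatenate(bad) if bad else np.array([], dtype=int)
```

The indices come out in ascending order, so they can be compared with np.array_equal just like in note 2.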
    
Answered By: Pierre D