vectorizing custom python function with numpy array

Question:

I'm not sure that's the correct terminology. Basically, I'm trying to take a black-and-white image and first transform it so that the white pixels that border black pixels remain white and every other pixel turns black. That part of the program works fine and is done in find_edges. Next, I need to calculate the distance from each element in the image to the closest white pixel. Right now I'm doing it with a for-loop, which is extremely slow. Is there a way to write find_nearest_edge purely with NumPy, without a for-loop that calls it on each element? Thanks.

####

from PIL import Image
import numpy as np
from scipy.ndimage import binary_erosion

####

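def find_nearest_edge(arr, point):
    h, w = arr.shape  # NumPy shape is (rows, cols) = (height, width)
    x, y = point      # x is the column, y is the row
    # Coordinate grids with the same (height, width) shape as arr
    xcoords, ycoords = np.meshgrid(np.arange(w), np.arange(h))

    # Distance from point to every pixel; black pixels are ignored
    target = np.sqrt((xcoords - x)**2 + (ycoords - y)**2)
    target[arr == 0] = np.inf

    # Smallest non-zero distance: the nearest white pixel other than point itself
    shortest_distance = np.min(target[target > 0.0])

    return shortest_distance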

def find_edges(img):
    img = img.convert('L')
    img_np = np.array(img)

    kernel = np.ones((3,3))
    edges = img_np - binary_erosion(img_np, kernel)*255

    return edges

a = Image.open('a.png')
x, y = a.size

edges = find_edges(a)

out = Image.fromarray(edges.astype('uint8'), 'L')
out.save('b.png')

dists = []
for _x in range(x):
    for _y in range(y):
        dist = find_nearest_edge(edges,(_x,_y))
        dists.append(dist)

print(dists)
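
To make the edge step concrete, here is a small sketch of the same erosion-and-subtract logic applied directly to a NumPy array. The name find_edges_array is hypothetical (it is find_edges without the PIL conversion), and the 7x7 test image is made up for illustration. Eroding with a 3x3 kernel keeps only the white pixels whose whole neighbourhood is white, so subtracting the eroded image leaves a one-pixel white border.

```python
import numpy as np
from scipy.ndimage import binary_erosion

def find_edges_array(img_np):
    # Hypothetical array-only variant of find_edges (no PIL conversion).
    # binary_erosion keeps a pixel only if its whole 3x3 neighbourhood is white;
    # subtracting that from the image leaves the white pixels that touch black.
    kernel = np.ones((3, 3))
    return img_np.astype(int) - binary_erosion(img_np, kernel) * 255

# Illustrative input: a 5x5 white square on a black 7x7 background
img = np.zeros((7, 7), dtype=np.uint8)
img[1:6, 1:6] = 255

edges = find_edges_array(img)
print((edges > 0).astype(int))  # ring of 1s around a 3x3 block of 0s
```

The 5x5 square has 25 white pixels; erosion keeps its inner 3x3 block, so 16 border pixels survive as edges.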

Asked By: user18615293

Answers:

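You can use scipy.spatial.KDTree to compute the distances quickly: build the tree once from the coordinates of the white pixels, then query it with the coordinates of every pixel in a single call.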

import numpy as np
import matplotlib.pyplot as plt

from scipy.ndimage import binary_erosion
from scipy.spatial import KDTree


def find_edges(img):
    img_np = np.array(img)

    kernel = np.ones((3,3))
    edges = img_np - binary_erosion(img_np, kernel)*255

    return edges


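def find_closest_distance(img):
    # NOTE: assuming the input is a binary image and white is any non-zero value!
    # (row, col) coordinates of all white pixels; .T gives shape (N, 2)
    white_pixel_points = np.array(np.where(img))
    tree = KDTree(white_pixel_points.T)
    # (row, col) coordinates of every pixel, shape (rows, cols, 2)
    img_meshgrid = np.array(np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]))).T
    # Distance from each pixel to its nearest white pixel, shape (rows, cols)
    distances, _ = tree.query(img_meshgrid)
    return distances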

test_image = np.zeros((200, 200))
rectangle = np.ones((30, 80))
test_image[20:50, 60:140] = rectangle
test_image[150:180, 60:140] = rectangle
test_image[60:140, 20:50] = rectangle.T
test_image[60:140, 150:180] = rectangle.T
test_image = test_image * 255
edge_image = find_edges(test_image)
distance_image = find_closest_distance(edge_image)


fig, axes = plt.subplots(1, 3, figsize=(12, 5))
axes[0].imshow(test_image, cmap='Greys_r')
axes[1].imshow(edge_image, cmap='Greys_r')
axes[2].imshow(distance_image, cmap='Greys_r')
plt.show()
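
A quick way to check the KDTree approach is to run it on a tiny image whose answer can be worked out by hand and compare it with a brute-force computation. The sketch below is a lightly rewritten find_closest_distance that builds the same (row, col) coordinates with np.argwhere and np.indices instead of np.where and np.meshgrid; the 4x5 test image is an illustrative assumption.

```python
import numpy as np
from scipy.spatial import KDTree

def find_closest_distance(img):
    # KD-tree over the (row, col) coordinates of the non-zero (white) pixels
    white_pixel_points = np.argwhere(img)              # shape (N, 2)
    tree = KDTree(white_pixel_points)
    # (row, col) coordinates of every pixel, shape (rows, cols, 2)
    queries = np.stack(np.indices(img.shape), axis=-1)
    distances, _ = tree.query(queries)                 # shape (rows, cols)
    return distances

# Tiny test image with white pixels in two opposite corners
img = np.zeros((4, 5))
img[0, 0] = 255
img[3, 4] = 255
d = find_closest_distance(img)

# Brute-force reference: distance to every white pixel, keep the minimum
rows, cols = np.indices(img.shape)
white_rows, white_cols = np.nonzero(img)
brute = np.sqrt((rows[..., None] - white_rows) ** 2
                + (cols[..., None] - white_cols) ** 2).min(axis=-1)
print(np.allclose(d, brute))  # True
```

For example, pixel (0, 4) is 4 away from (0, 0) and 3 away from (3, 4), so its distance is 3.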

Answered By: dankal444

You can make your code 25x faster just by compiling find_nearest_edge with Numba, as shown below. Many other optimizations are possible, but this function is the biggest bottleneck in your code.

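import numpy as np
from numba import njit

@njit
def find_nearest_edge(arr, point):
    x, y = point
    shortest_distance = np.inf
    # Plain nested loops are fine here: Numba compiles them to machine code
    for i in range(arr.shape[0]):
        for j in range(arr.shape[1]):
            if arr[i,j] == 0: continue  # skip black pixels
            # Compare squared distances; take the square root only once at the end
            shortest_distance = min(shortest_distance, (i-x)**2 + (j-y)**2)
    return np.sqrt(shortest_distance)
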
Answered By: AboAmmar
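
Neither answer mentions it, but SciPy has a function aimed at exactly this task: scipy.ndimage.distance_transform_edt returns, for every non-zero element of its input, the Euclidean distance to the nearest zero element. Passing edges == 0 turns the white edge pixels into the zeros, so every pixel receives its distance to the nearest edge (0 on the edges themselves, unlike the question's version, which skips zero distances). The sketch below compares it against a loop-free NumPy broadcasting version, which literally answers the "solely with numpy" question but needs memory proportional to (number of pixels) x (number of white pixels). Both function names are hypothetical, and the 6x8 test array is illustrative.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def nearest_white_distance_edt(edges):
    # distance_transform_edt gives each non-zero element its distance to the
    # nearest zero element; edges == 0 makes the white pixels the zeros.
    return distance_transform_edt(edges == 0)

def nearest_white_distance_broadcast(edges):
    # Loop-free NumPy: distances from every pixel to every white pixel, then
    # the minimum over the white pixels. Memory grows as rows * cols * N_white.
    rows, cols = np.indices(edges.shape)
    white_rows, white_cols = np.nonzero(edges)
    d2 = (rows[..., None] - white_rows) ** 2 + (cols[..., None] - white_cols) ** 2
    return np.sqrt(d2.min(axis=-1))

# Illustrative edge image with two white pixels
edges = np.zeros((6, 8))
edges[1, 2] = 255
edges[4, 6] = 255

edt = nearest_white_distance_edt(edges)
brute = nearest_white_distance_broadcast(edges)
print(np.allclose(edt, brute))  # True
```

On a real image, something like distance_transform_edt(find_edges(a) == 0) should replace the whole double loop with a single call.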