Replace colors in image by closest color in palette using numpy

Question:

I have a list of colors and a function closest_color(pixel, colors) that compares a given pixel’s RGB values with that list and returns the closest color from it.

I need to apply this function to a whole image. When I apply it pixel by pixel (using two nested for-loops), it is slow. Is there a better way to achieve this with numpy?

Asked By: stackmodern


Answers:

1. Option: Single image evaluation (slow)

Pros

 - any palette any time (flexible)

Cons

 - slow
 - memory use grows with the number of colors in the palette
 - not good for batch processing

2. Option: Batch processing (super fast)

Pros
 - super fast (50 ms per image), independent of palette size
 - low memory, independent of image size or palette size
 - ideal for batch processing if the palette doesn’t change
 - simple code
Cons
 - requires creating the color cube (once, up to 3 minutes)
 - a color cube can hold only one palette

Requirements
 - the color cube requires 1.5 MB of disk space as a compressed NumPy array

Option 1:

Take the image, create a palette container with the same height and width as the image, calculate the distances, and build the new image from the np.argmin indices.

import numpy as np
from PIL import Image
import requests

# get some image
im = Image.open(requests.get("https://upload.wikimedia.org/wikipedia/commons/thumb/7/77/Big_Nature_%28155420955%29.jpeg/800px-Big_Nature_%28155420955%29.jpeg", stream=True).raw)
newsize = (1000, 1000)
im = im.resize(newsize)
# im.show()
im = np.asarray(im)

# The image has shape (1000,1000,3). Add an axis of length 1 so that it is easy
# to subtract from the color container: (1000,1000,1,3)
image = im.reshape(im.shape[0],im.shape[1],1,3)



# test colors
colors = [[0,0,0],[255,255,255],[0,0,255]]

# Create color container 
## It has same dimensions as image (1000,1000,number of colors,3)
colors_container = np.ones(shape=[image.shape[0],image.shape[1],len(colors),3])
for i,color in enumerate(colors):
    colors_container[:,:,i,:] = color



def closest(image,color_container):
    shape = image.shape[:2]
    total_shape = shape[0]*shape[1]
    n_colors = color_container.shape[2]

    # Euclidean distance of every pixel to every palette color
    ### shape = (x,y,number of colors)
    distances = np.sqrt(np.sum((color_container-image)**2,axis=3))

    # Index of the closest palette color for every pixel
    ## argmin returns shape (x,y); flatten it to (x*y)
    min_index = np.argmin(distances,axis=2).reshape(-1)
    # Natural index: flat pixel positions 0 .. x*y-1, paired with min_index below
    natural_index = np.arange(total_shape)

    # Flatten the container for easy index access
    ## shape is (x*y,number of colors,3)
    reshaped_container = color_container.reshape(-1,n_colors,3)

    # For every pixel position, pick the color at its closest-color index
    color_view = reshaped_container[natural_index,min_index].reshape(shape[0],shape[1],3)
    return color_view

# NOTE: Don't pass two uint8 arrays, the subtraction would overflow.
# Here the color container is float64 (np.ones default), so the result is float64.
result_image = closest(image,colors_container)

Image.fromarray(result_image.astype(np.uint8)).show()
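
A possible variation on Option 1, as a minimal sketch: NumPy broadcasting can stand in for the image-sized color container, so the palette stays a small (number of colors, 3) array. The function name closest_palette_colors is made up for this example. The temporary difference array still has shape (height, width, number of colors, 3), so memory use still grows with the palette size.

```python
import numpy as np

def closest_palette_colors(image, palette):
    """Map every pixel of an (H, W, 3) image to its nearest palette color."""
    # int32 avoids uint8 wrap-around during the subtraction
    image = np.asarray(image, dtype=np.int32)
    palette = np.asarray(palette, dtype=np.int32)

    # (H, W, 1, 3) - (n_colors, 3) broadcasts to (H, W, n_colors, 3)
    diff = image[:, :, None, :] - palette
    # squared Euclidean distance; the square root is not needed for argmin
    sq_dist = (diff ** 2).sum(axis=3)

    # index of the closest palette color per pixel, shape (H, W)
    nearest = sq_dist.argmin(axis=2)
    return palette[nearest].astype(np.uint8)

# tiny made-up test image and the test colors from above
demo_image = np.array([[[10, 10, 10], [240, 250, 245]],
                       [[5, 5, 200], [100, 100, 100]]], dtype=np.uint8)
demo_palette = [[0, 0, 0], [255, 255, 255], [0, 0, 255]]
demo_result = closest_palette_colors(demo_image, demo_palette)
print(demo_result)
```

Skipping the square root does not change which color wins, because the square root is monotonic.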

Option 2:

Build a color cube of size 256x256x256x3 based on your palette. In other words, for every possible RGB color, store the palette color that is closest to it. Save the color cube once, the first time, and load it on later runs. To convert an image, use every pixel’s color as an index into the color cube.

import numpy as np
from PIL import Image
import requests
import time
# get some image
im = Image.open(requests.get("https://helpx.adobe.com/content/dam/help/en/photoshop/using/convert-color-image-black-white/jcr_content/main-pars/before_and_after/image-before/Landscape-Color.jpg", stream=True).raw)
newsize = (1000, 1000)
im = im.resize(newsize)
im = np.asarray(im)


### Initialization: Do just once
# Step 1: Define palette
palette = np.array([[255,255,255],[125,0,0],[0,0,125],[0,0,0]])

# Step 2: Create/load the precalculated color cube
try:
    # the cube assigns a palette color to each of the 256*256*256 colors
    precalculated = np.load('view.npz')['color_cube']
except FileNotFoundError:
    # uint8 keeps the cube at 256*256*256*3 bytes (about 50 MB) in memory
    precalculated = np.zeros(shape=[256,256,256,3], dtype=np.uint8)
    for i in range(256):
        print('processing',100*i/256)
        for j in range(256):
            for k in range(256):
                index = np.argmin(np.sqrt(np.sum(((palette)-np.array([i,j,k]))**2,axis=1)))
                precalculated[i,j,k] = palette[index]
    np.savez_compressed('view', color_cube = precalculated)
        

# Processing part
#### Step 1: Take the precalculated color cube for the defined palette and look up every pixel color in it

def get_view(color_cube,image):
    shape = image.shape[0:2]
    indices = image.reshape(-1,3)
    # pass image colors and retrieve corresponding palette color
    new_image = color_cube[indices[:,0],indices[:,1],indices[:,2]]
   
    return new_image.reshape(shape[0],shape[1],3).astype(np.uint8)

start = time.time()
result = get_view(precalculated,im)
print('Image processing: ',time.time()-start)
Image.fromarray(result).show()
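
The triple loop above is what makes building the cube take minutes. As a hedged sketch of one way to shorten it, the squared distance can be split into per-channel terms and evaluated one slice at a time with NumPy. The names build_index_cube and apply_cube are hypothetical, and this variant stores palette indices (uint8, so at most 256 colors) instead of the colors themselves, which makes the cube 256*256*256 bytes.

```python
import numpy as np

def build_index_cube(palette):
    """Return a (256, 256, 256) uint8 array: nearest palette index per color."""
    palette = np.asarray(palette, dtype=np.int32)
    if len(palette) > 256:
        raise ValueError("uint8 indices only cover palettes of up to 256 colors")

    values = np.arange(256, dtype=np.int32)
    # squared per-channel differences, each of shape (256, n_colors)
    d0 = (values[:, None] - palette[None, :, 0]) ** 2
    d1 = (values[:, None] - palette[None, :, 1]) ** 2
    d2 = (values[:, None] - palette[None, :, 2]) ** 2

    cube = np.empty((256, 256, 256), dtype=np.uint8)
    for i in range(256):
        # squared distances of all colors whose first channel equals i,
        # shape (256, 256, n_colors)
        sq_dist = d0[i] + d1[:, None, :] + d2[None, :, :]
        cube[i] = np.argmin(sq_dist, axis=2)
    return cube

def apply_cube(cube, image, palette):
    """Look up every pixel of a uint8 (H, W, 3) image in the index cube."""
    indices = cube[image[..., 0], image[..., 1], image[..., 2]]
    return np.asarray(palette, dtype=np.uint8)[indices]

demo_palette = [[255, 255, 255], [125, 0, 0], [0, 0, 125], [0, 0, 0]]
demo_cube = build_index_cube(demo_palette)
demo_image = np.array([[[0, 0, 0], [250, 250, 250]]], dtype=np.uint8)
demo_result = apply_cube(demo_cube, demo_image, demo_palette)
print(demo_result)
```

Ties are resolved the same way as in the loop version, since np.argmin returns the first minimum in both cases.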
Answered By: Martin

The task is to turn a picture into a palette version of it. You define a palette, and then you need to find, for every pixel, the nearest neighbor match in the defined palette for that pixel’s color. You get an index from that lookup, which you can then turn into the palette color for that pixel.
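
As a small illustration of that last step, a per-pixel index map can be expanded into colors with NumPy integer array indexing. The palette and the index values below are made up for the example and stand in for the result of a real nearest-neighbor lookup.

```python
import numpy as np

# made-up palette of three colors
palette = np.array([[0, 0, 0], [255, 255, 255], [0, 0, 255]], dtype=np.uint8)

# hypothetical lookup result: one palette index per pixel of a 2x2 image
indices = np.array([[0, 1],
                    [2, 0]], dtype=np.uint8)

# indexing the palette with the index map yields an image of shape (2, 2, 3)
output = palette[indices]
print(output.shape)
```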

This is possible using FLANN (comes with OpenCV). It’s not much code either. The lookups take two seconds on my old computer.

One advantage of this approach is that it can handle "large" palettes without requiring lots of memory. This is not unique to FLANN however. What is unique to FLANN is probably how little (user-side) code it needs.

Disadvantage: this still takes a few seconds.

FLANN uses index structures and can handle arbitrary vectors, and it uses float32 types. Due to the index structures in FLANN, it performs sub-linearly (probably O(log(n)) or something similar), i.e. better than a "linear scan" (O(n)). However, the cost of FLANN’s complexity and generality would only be amortized by the better lookup complexity once the palette becomes huge. I present a "linear scan", with code specific to this problem, in another answer using numba.

Full notebook: https://gist.github.com/crackwitz/bbb1aff9b7c6c744665715a5337192c0

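import cv2 as cv
import numpy as np

# im is assumed to be a BGR image of shape (height, width, 3), e.g. from cv.imread
height, width = im.shape[:2]

# set up FLANN
# somewhat arbitrary parameters because under-documented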
norm = cv.NORM_L2
FLANN_INDEX_KDTREE = 1
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
search_params = dict(checks=50)
fm = cv.FlannBasedMatcher(index_params, search_params)

# make up a palette and give it to FLANN
levels = (0, 64, 128, 192, 255)
palette = np.uint8([
    [b,g,r]
    for b in levels
    for g in levels
    for r in levels
])
print("palette size:", len(palette))
fm.add(np.float32([palette])) # extra dimension is "pictures", unused
fm.train()

# find nearest neighbor matches for all pixels
queries = im.reshape((-1, 3)).astype(np.float32)
matches = fm.match(queries)

# get match indices and distances
assert len(palette) <= 256
indices = np.uint8([m.trainIdx for m in matches]).reshape(height, width)
dist = np.float32([m.distance for m in matches]).reshape(height, width)

# indices to palette colors
output = palette[indices]
# imshow(output)

[Image: lena in 125 colors or less]

Answered By: Christoph Rackwitz

Here are two variants using numba, a JIT compiler for Python code.

import numpy as np
from numba import njit, prange

The first variant uses more numpy primitives (np.argmin) and hence "more" memory. Maybe the little bit of memory has an effect, or maybe numba calls numpy routines as is, without being able to optimize those.

@njit(parallel=True)
def lookup1(palette, im):
    palette = palette.astype(np.int32)
    (rows,cols) = im.shape[:2]
    result = np.zeros((rows, cols), dtype=np.uint8)
    
    for i in prange(rows):
        for j in range(cols):
            sqdists = ((im[i,j] - palette) ** 2).sum(axis=1)
            index = np.argmin(sqdists)
            result[i,j] = index

    return result

I get ~180-190 ms per run on lena.jpg and a palette of 125 colors.

The second variant uses more hand-written code to replace most of the numpy primitives, which makes it even faster.

@njit(parallel=True)
def lookup2(palette, im):
    (rows,cols) = im.shape[:2]
    result = np.zeros((rows, cols), dtype=np.uint8)
    
    for i in prange(rows): # parallelize over this
        for j in range(cols):
            pb,pg,pr = im[i,j] # take pixel apart
            bestindex = -1
            bestdist = 2**20
            for index in range(len(palette)):
                cb,cg,cr = palette[index] # take palette color apart
                dist = (pb-cb)**2 + (pg-cg)**2 + (pr-cr)**2
                if dist < bestdist:
                    bestdist = dist
                    bestindex = index
            
            result[i,j] = bestindex
    
    return result

30 ms per run!

I think that’s approaching the theoretical maximum to within an order of magnitude. I figure that from the required math operations.

  • per palette entry: A = 10 ops

    3 subtracts, 3 squares, 3 adds, 1 compare

  • per pixel: B = 1375 ops

    len(palette) * (A+1), one index increment

  • per row: C = 704512 ops

    ncols * (B+1), one index increment

  • per image: D = 360710656 ops

    nrows * (C+1), one index increment

In 30 ms, on my ancient quad-core with hyperthreading, that works out to about 12000 MIPS (I won’t say flop/s because there is no floating point). That means close to one instruction per cycle. I’m sure the code lacks some SIMD vectorization… one could investigate what LLVM thinks of these loops but I won’t bother with that now.
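
The estimate above can be reproduced with a few lines of arithmetic. This sketch assumes a 512x512 image, which is what the per-row and per-image figures imply, together with the 125-color palette and the 30 ms runtime quoted above.

```python
palette_size = 125       # colors in the palette
ncols = nrows = 512      # assumption: inferred from the figures above
runtime_s = 0.030        # measured runtime of lookup2

ops_per_entry = 10                                   # A: 3 subtracts, 3 squares, 3 adds, 1 compare
ops_per_pixel = palette_size * (ops_per_entry + 1)   # B: plus one index increment per entry
ops_per_row = ncols * (ops_per_pixel + 1)            # C: plus one index increment per pixel
ops_per_image = nrows * (ops_per_row + 1)            # D: plus one index increment per row

# million operations per second over the whole run
mips = ops_per_image / runtime_s / 1e6

print(ops_per_pixel, ops_per_row, ops_per_image, round(mips))
```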

Some code in cython might be able to beat this because there you can tie down the types of variables even more.

The notebook: https://gist.github.com/crackwitz/208a1ed8ff470ad70ae41e2061111f02

Answered By: Christoph Rackwitz